26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
67 static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
71 val1.getValue() & val2.getValue());
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
98 llvm_unreachable(
"unknown cmpi predicate kind");
107 static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
121 llvm_unreachable(
"Unhandled rounding mode");
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
162 #include "ArithCanonicalization.inc"
172 if (
auto shapedType = dyn_cast<ShapedType>(type))
173 return shapedType.cloneWith(std::nullopt, i1Type);
174 if (llvm::isa<UnrankedTensorType>(type))
183 void arith::ConstantOp::getAsmResultNames(
186 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
187 auto intType = dyn_cast<IntegerType>(type);
190 if (intType && intType.getWidth() == 1)
191 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
195 llvm::raw_svector_ostream specialName(specialNameBuffer);
196 specialName <<
'c' << intCst.getValue();
198 specialName <<
'_' << type;
199 setNameFn(getResult(), specialName.str());
201 setNameFn(getResult(),
"cst");
210 if (llvm::isa<IntegerType>(type) &&
211 !llvm::cast<IntegerType>(type).isSignless())
212 return emitOpError(
"integer return type must be signless");
214 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216 "value must be an integer, float, or elements attribute");
222 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224 "intializing scalable vectors with elements attribute is not supported"
225 " unless it's a vector splat");
229 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
231 auto typedAttr = dyn_cast<TypedAttr>(value);
232 if (!typedAttr || typedAttr.getType() != type)
235 if (llvm::isa<IntegerType>(type) &&
236 !llvm::cast<IntegerType>(type).isSignless())
239 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
244 if (isBuildableWith(value, type))
245 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
249 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
252 int64_t value,
unsigned width) {
254 arith::ConstantOp::build(builder, result, type,
263 build(builder, state, value, width);
264 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
265 assert(result &&
"builder didn't return the right type");
272 return create(builder, builder.
getLoc(), value, width);
276 Type type, int64_t value) {
277 arith::ConstantOp::build(builder, result, type,
285 build(builder, state, type, value);
286 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
287 assert(result &&
"builder didn't return the right type");
292 Type type, int64_t value) {
293 return create(builder, builder.
getLoc(), type, value);
297 Type type,
const APInt &value) {
298 arith::ConstantOp::build(builder, result, type,
304 const APInt &value) {
306 build(builder, state, type, value);
307 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
308 assert(result &&
"builder didn't return the right type");
314 const APInt &value) {
315 return create(builder, builder.
getLoc(), type, value);
319 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
320 return constOp.getType().isSignlessInteger();
325 FloatType type,
const APFloat &value) {
326 arith::ConstantOp::build(builder, result, type,
333 const APFloat &value) {
335 build(builder, state, type, value);
336 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
337 assert(result &&
"builder didn't return the right type");
343 const APFloat &value) {
344 return create(builder, builder.
getLoc(), type, value);
348 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
349 return llvm::isa<FloatType>(constOp.getType());
355 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
363 build(builder, state, value);
364 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
365 assert(result &&
"builder didn't return the right type");
371 return create(builder, builder.
getLoc(), value);
375 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
376 return constOp.getType().isIndex();
384 "type doesn't have a zero representation");
386 assert(zeroAttr &&
"unsupported type for zero attribute");
387 return arith::ConstantOp::create(builder, loc, zeroAttr);
400 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
401 if (getRhs() == sub.getRhs())
405 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
406 if (getLhs() == sub.getRhs())
409 return constFoldBinaryOp<IntegerAttr>(
410 adaptor.getOperands(),
411 [](APInt a,
const APInt &b) { return std::move(a) + b; });
416 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
417 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
424 std::optional<SmallVector<int64_t, 4>>
425 arith::AddUIExtendedOp::getShapeForUnroll() {
426 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
427 return llvm::to_vector<4>(vt.getShape());
434 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
438 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
440 Type overflowTy = getOverflow().getType();
446 results.push_back(getLhs());
447 results.push_back(falseValue);
455 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
456 adaptor.getOperands(),
457 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
458 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
459 ArrayRef({sumAttr, adaptor.getLhs()}),
465 results.push_back(sumAttr);
466 results.push_back(overflowAttr);
473 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
475 patterns.add<AddUIExtendedToAddI>(context);
484 if (getOperand(0) == getOperand(1)) {
485 auto shapedType = dyn_cast<ShapedType>(
getType());
487 if (!shapedType || shapedType.hasStaticShape())
494 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
496 if (getRhs() == add.getRhs())
499 if (getRhs() == add.getLhs())
503 return constFoldBinaryOp<IntegerAttr>(
504 adaptor.getOperands(),
505 [](APInt a,
const APInt &b) { return std::move(a) - b; });
510 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
511 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
512 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
529 return constFoldBinaryOp<IntegerAttr>(
530 adaptor.getOperands(),
531 [](
const APInt &a,
const APInt &b) { return a * b; });
534 void arith::MulIOp::getAsmResultNames(
536 if (!isa<IndexType>(
getType()))
542 return op && op->getName().getStringRef() ==
"vector.vscale";
545 IntegerAttr baseValue;
548 isVscale(b.getDefiningOp());
551 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
556 llvm::raw_svector_ostream specialName(specialNameBuffer);
557 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
558 setNameFn(getResult(), specialName.str());
563 patterns.add<MulIMulIConstant>(context);
570 std::optional<SmallVector<int64_t, 4>>
571 arith::MulSIExtendedOp::getShapeForUnroll() {
572 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
573 return llvm::to_vector<4>(vt.getShape());
578 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
583 results.push_back(zero);
584 results.push_back(zero);
589 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
590 adaptor.getOperands(),
591 [](
const APInt &a,
const APInt &b) { return a * b; })) {
593 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
594 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
595 return llvm::APIntOps::mulhs(a, b);
597 assert(highAttr &&
"Unexpected constant-folding failure");
599 results.push_back(lowAttr);
600 results.push_back(highAttr);
607 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
609 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
616 std::optional<SmallVector<int64_t, 4>>
617 arith::MulUIExtendedOp::getShapeForUnroll() {
618 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
619 return llvm::to_vector<4>(vt.getShape());
624 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
629 results.push_back(zero);
630 results.push_back(zero);
638 results.push_back(getLhs());
639 results.push_back(zero);
644 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
645 adaptor.getOperands(),
646 [](
const APInt &a,
const APInt &b) { return a * b; })) {
648 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
649 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
650 return llvm::APIntOps::mulhu(a, b);
652 assert(highAttr &&
"Unexpected constant-folding failure");
654 results.push_back(lowAttr);
655 results.push_back(highAttr);
662 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
664 patterns.add<MulUIExtendedToMulI>(context);
673 arith::IntegerOverflowFlags ovfFlags) {
675 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
678 if (mul.getLhs() == rhs)
681 if (mul.getRhs() == rhs)
687 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
693 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
698 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
699 [&](APInt a,
const APInt &b) {
727 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
733 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
737 bool overflowOrDiv0 =
false;
738 auto result = constFoldBinaryOp<IntegerAttr>(
739 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
740 if (overflowOrDiv0 || !b) {
741 overflowOrDiv0 = true;
744 return a.sdiv_ov(b, overflowOrDiv0);
747 return overflowOrDiv0 ?
Attribute() : result;
774 APInt one(a.getBitWidth(), 1,
true);
775 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
776 return val.sadd_ov(one, overflow);
783 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
788 bool overflowOrDiv0 =
false;
789 auto result = constFoldBinaryOp<IntegerAttr>(
790 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
791 if (overflowOrDiv0 || !b) {
792 overflowOrDiv0 = true;
795 APInt quotient = a.udiv(b);
798 APInt one(a.getBitWidth(), 1,
true);
799 return quotient.uadd_ov(one, overflowOrDiv0);
802 return overflowOrDiv0 ?
Attribute() : result;
813 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
821 bool overflowOrDiv0 =
false;
822 auto result = constFoldBinaryOp<IntegerAttr>(
823 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
824 if (overflowOrDiv0 || !b) {
825 overflowOrDiv0 = true;
831 unsigned bits = a.getBitWidth();
833 bool aGtZero = a.sgt(zero);
834 bool bGtZero = b.sgt(zero);
835 if (aGtZero && bGtZero) {
842 bool overflowNegA =
false;
843 bool overflowNegB =
false;
844 bool overflowDiv =
false;
845 bool overflowNegRes =
false;
846 if (!aGtZero && !bGtZero) {
848 APInt posA = zero.ssub_ov(a, overflowNegA);
849 APInt posB = zero.ssub_ov(b, overflowNegB);
851 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
854 if (!aGtZero && bGtZero) {
856 APInt posA = zero.ssub_ov(a, overflowNegA);
857 APInt div = posA.sdiv_ov(b, overflowDiv);
858 APInt res = zero.ssub_ov(div, overflowNegRes);
859 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
863 APInt posB = zero.ssub_ov(b, overflowNegB);
864 APInt div = a.sdiv_ov(posB, overflowDiv);
865 APInt res = zero.ssub_ov(div, overflowNegRes);
867 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
871 return overflowOrDiv0 ?
Attribute() : result;
882 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
888 bool overflowOrDiv =
false;
889 auto result = constFoldBinaryOp<IntegerAttr>(
890 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
892 overflowOrDiv = true;
895 return a.sfloordiv_ov(b, overflowOrDiv);
898 return overflowOrDiv ?
Attribute() : result;
905 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
912 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
913 [&](APInt a,
const APInt &b) {
914 if (div0 || b.isZero()) {
928 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
935 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
936 [&](APInt a,
const APInt &b) {
937 if (div0 || b.isZero()) {
953 for (
bool reversePrev : {
false,
true}) {
954 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
955 .getDefiningOp<arith::AndIOp>();
959 Value other = (reversePrev ? op.getLhs() : op.getRhs());
960 if (other != prev.getLhs() && other != prev.getRhs())
963 return prev.getResult();
975 intValue.isAllOnes())
980 intValue.isAllOnes())
985 intValue.isAllOnes())
992 return constFoldBinaryOp<IntegerAttr>(
993 adaptor.getOperands(),
994 [](APInt a,
const APInt &b) { return std::move(a) & b; });
1004 if (rhsVal.isZero())
1007 if (rhsVal.isAllOnes())
1008 return adaptor.getRhs();
1015 intValue.isAllOnes())
1016 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1020 intValue.isAllOnes())
1021 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1023 return constFoldBinaryOp<IntegerAttr>(
1024 adaptor.getOperands(),
1025 [](APInt a,
const APInt &b) { return std::move(a) | b; });
1032 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1037 if (getLhs() == getRhs())
1041 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1042 if (prev.getRhs() == getRhs())
1043 return prev.getLhs();
1044 if (prev.getLhs() == getRhs())
1045 return prev.getRhs();
1049 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1050 if (prev.getRhs() == getLhs())
1051 return prev.getLhs();
1052 if (prev.getLhs() == getLhs())
1053 return prev.getRhs();
1056 return constFoldBinaryOp<IntegerAttr>(
1057 adaptor.getOperands(),
1058 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
1063 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1070 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1072 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1073 return op.getOperand();
1074 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1075 [](
const APFloat &a) { return -a; });
1082 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1087 return constFoldBinaryOp<FloatAttr>(
1088 adaptor.getOperands(),
1089 [](
const APFloat &a,
const APFloat &b) { return a + b; });
1096 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1101 return constFoldBinaryOp<FloatAttr>(
1102 adaptor.getOperands(),
1103 [](
const APFloat &a,
const APFloat &b) { return a - b; });
1110 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1112 if (getLhs() == getRhs())
1119 return constFoldBinaryOp<FloatAttr>(
1120 adaptor.getOperands(),
1121 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
1128 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1130 if (getLhs() == getRhs())
1137 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1146 if (getLhs() == getRhs())
1152 if (intValue.isMaxSignedValue())
1155 if (intValue.isMinSignedValue())
1159 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1160 [](
const APInt &a,
const APInt &b) {
1161 return llvm::APIntOps::smax(a, b);
1171 if (getLhs() == getRhs())
1177 if (intValue.isMaxValue())
1180 if (intValue.isMinValue())
1184 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1185 [](
const APInt &a,
const APInt &b) {
1186 return llvm::APIntOps::umax(a, b);
1194 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1196 if (getLhs() == getRhs())
1203 return constFoldBinaryOp<FloatAttr>(
1204 adaptor.getOperands(),
1205 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1212 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1214 if (getLhs() == getRhs())
1221 return constFoldBinaryOp<FloatAttr>(
1222 adaptor.getOperands(),
1223 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1232 if (getLhs() == getRhs())
1238 if (intValue.isMinSignedValue())
1241 if (intValue.isMaxSignedValue())
1245 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1246 [](
const APInt &a,
const APInt &b) {
1247 return llvm::APIntOps::smin(a, b);
1257 if (getLhs() == getRhs())
1263 if (intValue.isMinValue())
1266 if (intValue.isMaxValue())
1270 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1271 [](
const APInt &a,
const APInt &b) {
1272 return llvm::APIntOps::umin(a, b);
1280 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1285 return constFoldBinaryOp<FloatAttr>(
1286 adaptor.getOperands(),
1287 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1299 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1304 return constFoldBinaryOp<FloatAttr>(
1305 adaptor.getOperands(),
1306 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1318 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1319 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1320 [](
const APFloat &a,
const APFloat &b) {
1325 (void)result.mod(b);
1334 template <
typename... Types>
1340 template <
typename... ShapedTypes,
typename... ElementTypes>
1343 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1347 if (!llvm::isa<ElementTypes...>(underlyingType))
1350 return underlyingType;
1354 template <
typename... ElementTypes>
1361 template <
typename... ElementTypes>
1370 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1371 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1372 if (!rankedTensorA || !rankedTensorB)
1374 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1378 if (inputs.size() != 1 || outputs.size() != 1)
1390 template <
typename ValType,
typename Op>
1395 if (llvm::cast<ValType>(srcType).getWidth() >=
1396 llvm::cast<ValType>(dstType).getWidth())
1398 << dstType <<
" must be wider than operand type " << srcType;
1404 template <
typename ValType,
typename Op>
1409 if (llvm::cast<ValType>(srcType).getWidth() <=
1410 llvm::cast<ValType>(dstType).getWidth())
1412 << dstType <<
" must be shorter than operand type " << srcType;
1418 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1423 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1424 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1425 if (!srcType || !dstType)
1428 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1429 srcType.getIntOrFloatBitWidth());
1435 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1436 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1437 bool losesInfo =
false;
1438 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1439 if (losesInfo || status != APFloat::opOK)
1449 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1450 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1451 getInMutable().assign(lhs.getIn());
1456 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1457 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1458 adaptor.getOperands(),
getType(),
1459 [bitWidth](
const APInt &a,
bool &castStatus) {
1460 return a.zext(bitWidth);
1465 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1469 return verifyExtOp<IntegerType>(*
this);
1476 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1477 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1478 getInMutable().assign(lhs.getIn());
1483 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1484 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1485 adaptor.getOperands(),
getType(),
1486 [bitWidth](
const APInt &a,
bool &castStatus) {
1487 return a.sext(bitWidth);
1492 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1497 patterns.add<ExtSIOfExtUI>(context);
1501 return verifyExtOp<IntegerType>(*
this);
1510 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1511 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1512 if (truncFOp.getOperand().getType() ==
getType()) {
1513 arith::FastMathFlags truncFMF =
1514 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1515 bool isTruncContract =
1517 arith::FastMathFlags extFMF =
1518 getFastmath().value_or(arith::FastMathFlags::none);
1519 bool isExtContract =
1521 if (isTruncContract && isExtContract) {
1522 return truncFOp.getOperand();
1528 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1529 return constFoldCastOp<FloatAttr, FloatAttr>(
1530 adaptor.getOperands(),
getType(),
1531 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1533 if (failed(result)) {
1542 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1551 bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1553 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1557 return verifyExtOp<FloatType>(*
this);
1564 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1565 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1572 if (llvm::cast<IntegerType>(srcType).getWidth() >
1573 llvm::cast<IntegerType>(dstType).getWidth()) {
1580 if (srcType == dstType)
1585 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1586 setOperand(getOperand().getDefiningOp()->getOperand(0));
1591 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1592 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1593 adaptor.getOperands(),
getType(),
1594 [bitWidth](
const APInt &a,
bool &castStatus) {
1595 return a.trunc(bitWidth);
1600 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1606 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1611 return verifyTruncateOp<IntegerType>(*
this);
1620 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1622 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1623 Value src = extOp.getIn();
1625 auto intermediateType =
1628 if (llvm::APFloatBase::isRepresentableBy(
1629 srcType.getFloatSemantics(),
1630 intermediateType.getFloatSemantics())) {
1632 if (srcType.getWidth() > resElemType.getWidth()) {
1638 if (srcType == resElemType)
1643 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1644 return constFoldCastOp<FloatAttr, FloatAttr>(
1645 adaptor.getOperands(),
getType(),
1646 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1647 RoundingMode roundingMode =
1648 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1649 llvm::RoundingMode llvmRoundingMode =
1651 FailureOr<APFloat> result =
1653 if (failed(result)) {
1663 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1667 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1671 return verifyTruncateOp<FloatType>(*
this);
1678 bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1680 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1684 return verifyTruncateOp<FloatType>(*
this);
1693 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1702 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1709 template <
typename From,
typename To>
1714 auto srcType = getTypeIfLike<From>(inputs.front());
1715 auto dstType = getTypeIfLike<To>(outputs.back());
1717 return srcType && dstType;
1725 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1728 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1730 return constFoldCastOp<IntegerAttr, FloatAttr>(
1731 adaptor.getOperands(),
getType(),
1732 [&resEleType](
const APInt &a,
bool &castStatus) {
1733 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1734 APFloat apf(floatTy.getFloatSemantics(),
1736 apf.convertFromAPInt(a,
false,
1737 APFloat::rmNearestTiesToEven);
1747 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1750 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1752 return constFoldCastOp<IntegerAttr, FloatAttr>(
1753 adaptor.getOperands(),
getType(),
1754 [&resEleType](
const APInt &a,
bool &castStatus) {
1755 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1756 APFloat apf(floatTy.getFloatSemantics(),
1758 apf.convertFromAPInt(a,
true,
1759 APFloat::rmNearestTiesToEven);
1769 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1772 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1774 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1775 return constFoldCastOp<FloatAttr, IntegerAttr>(
1776 adaptor.getOperands(),
getType(),
1777 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1779 APSInt api(bitWidth,
true);
1780 castStatus = APFloat::opInvalidOp !=
1781 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1791 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1794 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1796 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1797 return constFoldCastOp<FloatAttr, IntegerAttr>(
1798 adaptor.getOperands(),
getType(),
1799 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1801 APSInt api(bitWidth,
false);
1802 castStatus = APFloat::opInvalidOp !=
1803 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1816 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1817 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1818 if (!srcType || !dstType)
1825 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1830 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1832 unsigned resultBitwidth = 64;
1834 resultBitwidth = intTy.getWidth();
1836 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1837 adaptor.getOperands(),
getType(),
1838 [resultBitwidth](
const APInt &a,
bool & ) {
1839 return a.sextOrTrunc(resultBitwidth);
1843 void arith::IndexCastOp::getCanonicalizationPatterns(
1845 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1852 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1857 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1859 unsigned resultBitwidth = 64;
1861 resultBitwidth = intTy.getWidth();
1863 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1864 adaptor.getOperands(),
getType(),
1865 [resultBitwidth](
const APInt &a,
bool & ) {
1866 return a.zextOrTrunc(resultBitwidth);
1870 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1872 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1883 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1884 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1885 if (!srcType || !dstType)
1891 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1893 auto operand = adaptor.getIn();
1898 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1899 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1901 if (llvm::isa<ShapedType>(resType))
1905 if (llvm::isa<ub::PoisonAttr>(operand))
1909 APInt bits = llvm::isa<FloatAttr>(operand)
1910 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1911 : llvm::cast<IntegerAttr>(operand).getValue();
1913 "trying to fold on broken IR: operands have incompatible types");
1915 if (
auto resFloatType = dyn_cast<FloatType>(resType))
1917 APFloat(resFloatType.getFloatSemantics(), bits));
1923 patterns.add<BitcastOfBitcast>(context);
1933 const APInt &lhs,
const APInt &rhs) {
1934 switch (predicate) {
1935 case arith::CmpIPredicate::eq:
1937 case arith::CmpIPredicate::ne:
1939 case arith::CmpIPredicate::slt:
1940 return lhs.slt(rhs);
1941 case arith::CmpIPredicate::sle:
1942 return lhs.sle(rhs);
1943 case arith::CmpIPredicate::sgt:
1944 return lhs.sgt(rhs);
1945 case arith::CmpIPredicate::sge:
1946 return lhs.sge(rhs);
1947 case arith::CmpIPredicate::ult:
1948 return lhs.ult(rhs);
1949 case arith::CmpIPredicate::ule:
1950 return lhs.ule(rhs);
1951 case arith::CmpIPredicate::ugt:
1952 return lhs.ugt(rhs);
1953 case arith::CmpIPredicate::uge:
1954 return lhs.uge(rhs);
1956 llvm_unreachable(
"unknown cmpi predicate kind");
1961 switch (predicate) {
1962 case arith::CmpIPredicate::eq:
1963 case arith::CmpIPredicate::sle:
1964 case arith::CmpIPredicate::sge:
1965 case arith::CmpIPredicate::ule:
1966 case arith::CmpIPredicate::uge:
1968 case arith::CmpIPredicate::ne:
1969 case arith::CmpIPredicate::slt:
1970 case arith::CmpIPredicate::sgt:
1971 case arith::CmpIPredicate::ult:
1972 case arith::CmpIPredicate::ugt:
1975 llvm_unreachable(
"unknown cmpi predicate kind");
1979 if (
auto intType = dyn_cast<IntegerType>(t)) {
1980 return intType.getWidth();
1982 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
1983 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1985 return std::nullopt;
1988 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1990 if (getLhs() == getRhs()) {
1996 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1998 std::optional<int64_t> integerWidth =
2000 if (integerWidth && integerWidth.value() == 1 &&
2001 getPredicate() == arith::CmpIPredicate::ne)
2002 return extOp.getOperand();
2004 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2006 std::optional<int64_t> integerWidth =
2008 if (integerWidth && integerWidth.value() == 1 &&
2009 getPredicate() == arith::CmpIPredicate::ne)
2010 return extOp.getOperand();
2015 getPredicate() == arith::CmpIPredicate::ne)
2022 getPredicate() == arith::CmpIPredicate::eq)
2027 if (adaptor.getLhs() && !adaptor.getRhs()) {
2029 using Pred = CmpIPredicate;
2030 const std::pair<Pred, Pred> invPreds[] = {
2031 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2032 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2033 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2034 {Pred::ne, Pred::ne},
2036 Pred origPred = getPredicate();
2037 for (
auto pred : invPreds) {
2038 if (origPred == pred.first) {
2039 setPredicate(pred.second);
2040 Value lhs = getLhs();
2041 Value rhs = getRhs();
2042 getLhsMutable().assign(rhs);
2043 getRhsMutable().assign(lhs);
2047 llvm_unreachable(
"unknown cmpi predicate kind");
2052 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2053 return constFoldBinaryOp<IntegerAttr>(
2055 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
2066 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2076 const APFloat &lhs,
const APFloat &rhs) {
2077 auto cmpResult = lhs.compare(rhs);
2078 switch (predicate) {
2079 case arith::CmpFPredicate::AlwaysFalse:
2081 case arith::CmpFPredicate::OEQ:
2082 return cmpResult == APFloat::cmpEqual;
2083 case arith::CmpFPredicate::OGT:
2084 return cmpResult == APFloat::cmpGreaterThan;
2085 case arith::CmpFPredicate::OGE:
2086 return cmpResult == APFloat::cmpGreaterThan ||
2087 cmpResult == APFloat::cmpEqual;
2088 case arith::CmpFPredicate::OLT:
2089 return cmpResult == APFloat::cmpLessThan;
2090 case arith::CmpFPredicate::OLE:
2091 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2092 case arith::CmpFPredicate::ONE:
2093 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2094 case arith::CmpFPredicate::ORD:
2095 return cmpResult != APFloat::cmpUnordered;
2096 case arith::CmpFPredicate::UEQ:
2097 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2098 case arith::CmpFPredicate::UGT:
2099 return cmpResult == APFloat::cmpUnordered ||
2100 cmpResult == APFloat::cmpGreaterThan;
2101 case arith::CmpFPredicate::UGE:
2102 return cmpResult == APFloat::cmpUnordered ||
2103 cmpResult == APFloat::cmpGreaterThan ||
2104 cmpResult == APFloat::cmpEqual;
2105 case arith::CmpFPredicate::ULT:
2106 return cmpResult == APFloat::cmpUnordered ||
2107 cmpResult == APFloat::cmpLessThan;
2108 case arith::CmpFPredicate::ULE:
2109 return cmpResult == APFloat::cmpUnordered ||
2110 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2111 case arith::CmpFPredicate::UNE:
2112 return cmpResult != APFloat::cmpEqual;
2113 case arith::CmpFPredicate::UNO:
2114 return cmpResult == APFloat::cmpUnordered;
2115 case arith::CmpFPredicate::AlwaysTrue:
2118 llvm_unreachable(
"unknown cmpf predicate kind");
2121 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2122 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2123 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2126 if (lhs && lhs.getValue().isNaN())
2128 if (rhs && rhs.getValue().isNaN())
2144 using namespace arith;
2146 case CmpFPredicate::UEQ:
2147 case CmpFPredicate::OEQ:
2148 return CmpIPredicate::eq;
2149 case CmpFPredicate::UGT:
2150 case CmpFPredicate::OGT:
2151 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2152 case CmpFPredicate::UGE:
2153 case CmpFPredicate::OGE:
2154 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2155 case CmpFPredicate::ULT:
2156 case CmpFPredicate::OLT:
2157 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2158 case CmpFPredicate::ULE:
2159 case CmpFPredicate::OLE:
2160 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2161 case CmpFPredicate::UNE:
2162 case CmpFPredicate::ONE:
2163 return CmpIPredicate::ne;
2165 llvm_unreachable(
"Unexpected predicate!");
2175 const APFloat &rhs = flt.getValue();
2183 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2184 int mantissaWidth = floatTy.getFPMantissaWidth();
2185 if (mantissaWidth <= 0)
2191 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2193 intVal = si.getIn();
2194 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2196 intVal = ui.getIn();
2203 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2204 auto intWidth = intTy.getWidth();
2207 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2212 if ((
int)intWidth > mantissaWidth) {
2214 int exponent = ilogb(rhs);
2215 if (exponent == APFloat::IEK_Inf) {
2216 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2217 if (maxExponent < (
int)valueBits) {
2224 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2233 switch (op.getPredicate()) {
2234 case CmpFPredicate::ORD:
2239 case CmpFPredicate::UNO:
2252 APFloat signedMax(rhs.getSemantics());
2253 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2254 APFloat::rmNearestTiesToEven);
2255 if (signedMax < rhs) {
2256 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2257 pred == CmpIPredicate::sle)
2268 APFloat unsignedMax(rhs.getSemantics());
2269 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2270 APFloat::rmNearestTiesToEven);
2271 if (unsignedMax < rhs) {
2272 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2273 pred == CmpIPredicate::ule)
2285 APFloat signedMin(rhs.getSemantics());
2286 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2287 APFloat::rmNearestTiesToEven);
2288 if (signedMin > rhs) {
2289 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2290 pred == CmpIPredicate::sge)
2300 APFloat unsignedMin(rhs.getSemantics());
2301 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2302 APFloat::rmNearestTiesToEven);
2303 if (unsignedMin > rhs) {
2304 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2305 pred == CmpIPredicate::uge)
2320 APSInt rhsInt(intWidth, isUnsigned);
2321 if (APFloat::opInvalidOp ==
2322 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2328 if (!rhs.isZero()) {
2329 APFloat apf(floatTy.getFloatSemantics(),
2331 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2333 bool equal = apf == rhs;
2339 case CmpIPredicate::ne:
2343 case CmpIPredicate::eq:
2347 case CmpIPredicate::ule:
2350 if (rhs.isNegative()) {
2356 case CmpIPredicate::sle:
2359 if (rhs.isNegative())
2360 pred = CmpIPredicate::slt;
2362 case CmpIPredicate::ult:
2365 if (rhs.isNegative()) {
2370 pred = CmpIPredicate::ule;
2372 case CmpIPredicate::slt:
2375 if (!rhs.isNegative())
2376 pred = CmpIPredicate::sle;
2378 case CmpIPredicate::ugt:
2381 if (rhs.isNegative()) {
2387 case CmpIPredicate::sgt:
2390 if (rhs.isNegative())
2391 pred = CmpIPredicate::sge;
2393 case CmpIPredicate::uge:
2396 if (rhs.isNegative()) {
2401 pred = CmpIPredicate::ugt;
2403 case CmpIPredicate::sge:
2406 if (!rhs.isNegative())
2407 pred = CmpIPredicate::sgt;
2417 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2439 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2455 arith::XOrIOp::create(
2456 rewriter, op.getLoc(), op.getCondition(),
2457 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2458 op.getCondition().
getType(), 1)));
2468 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2472 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2473 Value trueVal = getTrueValue();
2474 Value falseVal = getFalseValue();
2475 if (trueVal == falseVal)
2478 Value condition = getCondition();
2489 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2492 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2496 if (
getType().isSignlessInteger(1) &&
2502 auto pred = cmp.getPredicate();
2503 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2504 auto cmpLhs = cmp.getLhs();
2505 auto cmpRhs = cmp.getRhs();
2513 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2514 (cmpRhs == trueVal && cmpLhs == falseVal))
2515 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2522 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2524 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2526 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2528 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2529 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2531 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2533 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2536 for (
auto [condVal, lhsVal, rhsVal] :
2537 llvm::zip_equal(condVals, lhsVals, rhsVals))
2538 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2549 Type conditionType, resultType;
2558 conditionType = resultType;
2567 {conditionType, resultType, resultType},
2572 p <<
" " << getOperands();
2575 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2576 p << condType <<
", ";
2581 Type conditionType = getCondition().getType();
2588 if (!llvm::isa<TensorType, VectorType>(resultType))
2589 return emitOpError() <<
"expected condition to be a signless i1, but got "
2592 if (conditionType != shapedConditionType) {
2593 return emitOpError() <<
"expected condition type to have the same shape "
2594 "as the result type, expected "
2595 << shapedConditionType <<
", but got "
2604 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2609 bool bounded =
false;
2610 auto result = constFoldBinaryOp<IntegerAttr>(
2611 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2612 bounded = b.ult(b.getBitWidth());
2622 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2627 bool bounded =
false;
2628 auto result = constFoldBinaryOp<IntegerAttr>(
2629 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2630 bounded = b.ult(b.getBitWidth());
2640 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2645 bool bounded =
false;
2646 auto result = constFoldBinaryOp<IntegerAttr>(
2647 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2648 bounded = b.ult(b.getBitWidth());
2661 bool useOnlyFiniteValue) {
2663 case AtomicRMWKind::maximumf: {
2664 const llvm::fltSemantics &semantic =
2665 llvm::cast<FloatType>(resultType).getFloatSemantics();
2666 APFloat identity = useOnlyFiniteValue
2667 ? APFloat::getLargest(semantic,
true)
2668 : APFloat::getInf(semantic,
true);
2671 case AtomicRMWKind::maxnumf: {
2672 const llvm::fltSemantics &semantic =
2673 llvm::cast<FloatType>(resultType).getFloatSemantics();
2674 APFloat identity = APFloat::getNaN(semantic,
true);
2677 case AtomicRMWKind::addf:
2678 case AtomicRMWKind::addi:
2679 case AtomicRMWKind::maxu:
2680 case AtomicRMWKind::ori:
2682 case AtomicRMWKind::andi:
2685 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2686 case AtomicRMWKind::maxs:
2688 resultType, APInt::getSignedMinValue(
2689 llvm::cast<IntegerType>(resultType).getWidth()));
2690 case AtomicRMWKind::minimumf: {
2691 const llvm::fltSemantics &semantic =
2692 llvm::cast<FloatType>(resultType).getFloatSemantics();
2693 APFloat identity = useOnlyFiniteValue
2694 ? APFloat::getLargest(semantic,
false)
2695 : APFloat::getInf(semantic,
false);
2699 case AtomicRMWKind::minnumf: {
2700 const llvm::fltSemantics &semantic =
2701 llvm::cast<FloatType>(resultType).getFloatSemantics();
2702 APFloat identity = APFloat::getNaN(semantic,
false);
2705 case AtomicRMWKind::mins:
2707 resultType, APInt::getSignedMaxValue(
2708 llvm::cast<IntegerType>(resultType).getWidth()));
2709 case AtomicRMWKind::minu:
2712 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2713 case AtomicRMWKind::muli:
2715 case AtomicRMWKind::mulf:
2727 std::optional<AtomicRMWKind> maybeKind =
2730 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2731 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2732 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2733 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2734 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2735 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2737 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2738 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2739 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2740 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2741 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2742 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2743 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2744 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2745 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2746 .Default([](
Operation *op) {
return std::nullopt; });
2748 return std::nullopt;
2751 bool useOnlyFiniteValue =
false;
2752 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2753 if (fmfOpInterface) {
2754 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2755 useOnlyFiniteValue =
2756 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2764 useOnlyFiniteValue);
2770 bool useOnlyFiniteValue) {
2773 return arith::ConstantOp::create(builder, loc, attr);
2781 case AtomicRMWKind::addf:
2782 return arith::AddFOp::create(builder, loc, lhs, rhs);
2783 case AtomicRMWKind::addi:
2784 return arith::AddIOp::create(builder, loc, lhs, rhs);
2785 case AtomicRMWKind::mulf:
2786 return arith::MulFOp::create(builder, loc, lhs, rhs);
2787 case AtomicRMWKind::muli:
2788 return arith::MulIOp::create(builder, loc, lhs, rhs);
2789 case AtomicRMWKind::maximumf:
2790 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2791 case AtomicRMWKind::minimumf:
2792 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2793 case AtomicRMWKind::maxnumf:
2794 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2795 case AtomicRMWKind::minnumf:
2796 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2797 case AtomicRMWKind::maxs:
2798 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2799 case AtomicRMWKind::mins:
2800 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2801 case AtomicRMWKind::maxu:
2802 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2803 case AtomicRMWKind::minu:
2804 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2805 case AtomicRMWKind::ori:
2806 return arith::OrIOp::create(builder, loc, lhs, rhs);
2807 case AtomicRMWKind::andi:
2808 return arith::AndIOp::create(builder, loc, lhs, rhs);
2821 #define GET_OP_CLASSES
2822 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2828 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1224::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)