#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                    Attribute rhs,
                    function_ref<APInt(const APInt &, const APInt &)> binFn) {
  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
  APInt value = binFn(lhsVal, rhsVal);
  return IntegerAttr::get(res.getType(), value);
}
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                   IntegerOverflowFlagsAttr val2) {
  return IntegerOverflowFlagsAttr::get(val1.getContext(),
                                       val1.getValue() & val2.getValue());
}
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
  switch (pred) {
  case arith::CmpIPredicate::eq:
    return arith::CmpIPredicate::ne;
  case arith::CmpIPredicate::ne:
    return arith::CmpIPredicate::eq;
  case arith::CmpIPredicate::slt:
    return arith::CmpIPredicate::sge;
  case arith::CmpIPredicate::sle:
    return arith::CmpIPredicate::sgt;
  case arith::CmpIPredicate::sgt:
    return arith::CmpIPredicate::sle;
  case arith::CmpIPredicate::sge:
    return arith::CmpIPredicate::slt;
  case arith::CmpIPredicate::ult:
    return arith::CmpIPredicate::uge;
  case arith::CmpIPredicate::ule:
    return arith::CmpIPredicate::ugt;
  case arith::CmpIPredicate::ugt:
    return arith::CmpIPredicate::ule;
  case arith::CmpIPredicate::uge:
    return arith::CmpIPredicate::ult;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
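// Illustrative use of the inversion above (informal example): canonicalization
// patterns such as XOrINotCmpI fold xor(cmpi(slt, %a, %b), true) by rewriting
// the predicate, so `arith.cmpi slt, %a, %b : i32` becomes
// `arith.cmpi sge, %a, %b : i32`.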
/// Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(
/// roundingMode)).
static llvm::RoundingMode
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
  switch (roundingMode) {
  case RoundingMode::downward:
    return llvm::RoundingMode::TowardNegative;
  case RoundingMode::to_nearest_away:
    return llvm::RoundingMode::NearestTiesToAway;
  case RoundingMode::to_nearest_even:
    return llvm::RoundingMode::NearestTiesToEven;
  case RoundingMode::toward_zero:
    return llvm::RoundingMode::TowardZero;
  case RoundingMode::upward:
    return llvm::RoundingMode::TowardPositive;
  }
  llvm_unreachable("Unhandled rounding mode");
}
static arith::CmpIPredicateAttr
invertPredicateAttr(arith::CmpIPredicateAttr pred) {
  return arith::CmpIPredicateAttr::get(pred.getContext(),
                                       invertPredicate(pred.getValue()));
}
static Attribute getBoolAttribute(Type type, bool value) {
  auto boolAttr = BoolAttr::get(type.getContext(), value);
  ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
  if (!shapedType)
    return boolAttr;
  return DenseElementsAttr::get(shapedType, boolAttr);
}
#include "ArithCanonicalization.inc"
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto shapedType = dyn_cast<ShapedType>(type))
    return shapedType.cloneWith(std::nullopt, i1Type);
  if (llvm::isa<UnrankedTensorType>(type))
    return UnrankedTensorType::get(i1Type);
  return i1Type;
}
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
    auto intType = dyn_cast<IntegerType>(type);

    // Sugar i1 constants with 'true' and 'false'.
    if (intType && intType.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a name with the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getValue();
    if (intType)
      specialName << '_' << type;
    setNameFn(getResult(), specialName.str());
  } else {
    setNameFn(getResult(), "cst");
  }
}
LogicalResult arith::ConstantOp::verify() {
  auto type = getType();
  // Integer values must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return emitOpError("integer return type must be signless");

  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
    return emitOpError(
        "value must be an integer, float, or elements attribute");
  }

  // Scalable vectors can only be initialized through splats.
  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
    return emitOpError(
        "initializing scalable vectors with elements attribute is not supported"
        " unless it's a vector splat");
  return success();
}
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // The value's type must be the same as the provided type.
  auto typedAttr = dyn_cast<TypedAttr>(value);
  if (!typedAttr || typedAttr.getType() != type)
    return false;
  // Integer values must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return false;
  // Integer, float, and element attributes are buildable.
  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}
ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {
  if (isBuildableWith(value, type))
    return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
  return nullptr;
}

OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  auto type = builder.getIntegerType(width);
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
                                           Location location, int64_t value,
                                           unsigned width) {
  mlir::OperationState state(location, getOperationName());
  build(builder, state, value, width);
  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
  assert(result && "builder didn't return the right type");
  return result;
}

void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 Type type, int64_t value) {
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
                                           Location location, Type type,
                                           int64_t value) {
  mlir::OperationState state(location, getOperationName());
  build(builder, state, type, value);
  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
  assert(result && "builder didn't return the right type");
  return result;
}

void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 Type type, const APInt &value) {
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
                                           Location location, Type type,
                                           const APInt &value) {
  mlir::OperationState state(location, getOperationName());
  build(builder, state, type, value);
  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
  assert(result && "builder didn't return the right type");
  return result;
}

ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
                                           Type type, const APInt &value) {
  return create(builder, builder.getLoc(), type, value);
}

bool arith::ConstantIntOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isSignlessInteger();
  return false;
}
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   FloatType type, const APFloat &value) {
  arith::ConstantOp::build(builder, result, type,
                           builder.getFloatAttr(type, value));
}

ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder,
                                               Location location,
                                               FloatType type,
                                               const APFloat &value) {
  mlir::OperationState state(location, getOperationName());
  build(builder, state, type, value);
  auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
  assert(result && "builder didn't return the right type");
  return result;
}

ConstantFloatOp arith::ConstantFloatOp::create(ImplicitLocOpBuilder &builder,
                                               FloatType type,
                                               const APFloat &value) {
  return create(builder, builder.getLoc(), type, value);
}

bool arith::ConstantFloatOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return llvm::isa<FloatType>(constOp.getType());
  return false;
}
ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder,
                                               Location location,
                                               int64_t value) {
  mlir::OperationState state(location, getOperationName());
  build(builder, state, value);
  auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
  assert(result && "builder didn't return the right type");
  return result;
}

bool arith::ConstantIndexOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isIndex();
  return false;
}
/// Creates an arith.constant operation with a zero value of type `type`.
Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
                                   Type type) {
  assert(type.isIntOrIndexOrFloat() &&
         "type doesn't have a zero representation");
  TypedAttr zeroAttr = builder.getZeroAttr(type);
  assert(zeroAttr && "unsupported type for zero attribute");
  return arith::ConstantOp::create(builder, loc, zeroAttr);
}
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
  // addi(subi(a, b), b) -> a
  if (auto sub = getLhs().getDefiningOp<SubIOp>())
    if (getRhs() == sub.getRhs())
      return sub.getLhs();

  // addi(b, subi(a, b)) -> a
  if (auto sub = getRhs().getDefiningOp<SubIOp>())
    if (getLhs() == sub.getRhs())
      return sub.getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}
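// Illustrative IR for the addi/subi folds above (i32 operands assumed):
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32   // folds to %a
//   %2 = arith.addi %b, %0 : i32   // also folds to %a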
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
  if (auto vt = dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

/// Returns the overflow bit, assuming that `sum` is the result of an unsigned
/// addition of `operand` and another number.
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
}

LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
  // addui_extended(x, 0) -> x, false
  if (matchPattern(getRhs(), m_Zero())) {
    Builder builder(getContext());
    auto falseValue = builder.getZeroAttr(overflowTy);

    results.push_back(getLhs());
    results.push_back(falseValue);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
        ArrayRef({sumAttr, adaptor.getLhs()}),
        getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
        calculateUnsignedOverflow);
    if (!overflowAttr)
      return failure();

    results.push_back(sumAttr);
    results.push_back(overflowAttr);
    return success();
  }

  return failure();
}

void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
}
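// Worked example for the extended-add fold above (illustrative, i8 operands):
// 200 + 100 = 300, which wraps to 44 in 8 bits; since 44 < 200 the carry bit
// computed by calculateUnsignedOverflow is 1, so
//   %sum, %carry = arith.addui_extended %c200, %c100 : i8, i1
// folds to the constants 44 : i8 and true : i1.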
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  // subi(x, x) -> 0
  if (getOperand(0) == getOperand(1)) {
    auto shapedType = dyn_cast<ShapedType>(getType());
    // We can't generate a constant with a dynamic shaped tensor.
    if (!shapedType || shapedType.hasStaticShape())
      return Builder(getContext()).getZeroAttr(getType());
  }

  // subi(addi(a, b), b) -> a, and subi(addi(a, b), a) -> b
  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
    if (getRhs() == add.getRhs())
      return add.getLhs();
    if (getRhs() == add.getLhs())
      return add.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}

void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
}
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
  // muli(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();
  // muli(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}

void arith::MulIOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!isa<IndexType>(getType()))
    return;

  // Match the `vector.vscale` op by name to avoid a dependency on the vector
  // dialect.
  auto isVscale = [](Operation *op) {
    return op && op->getName().getStringRef() == "vector.vscale";
  };

  IntegerAttr baseValue;
  auto isVscaleExpr = [&](Value a, Value b) {
    return matchPattern(a, m_Constant(&baseValue)) &&
           isVscale(b.getDefiningOp());
  };

  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
    return;

  // Name `base * vscale` results as `c<base>_vscale`.
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << baseValue.getInt() << "_vscale";
  setNameFn(getResult(), specialName.str());
}

void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
}
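// Example of the asm naming above (illustrative): with
//   %c4 = arith.constant 4 : index
//   %vs = vector.vscale
//   %0 = arith.muli %c4, %vs : index
// the result %0 is printed with the name %c4_vscale.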
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
  if (auto vt = dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
}
std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
  if (auto vt = dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(adaptor.getRhs(), m_One())) {
    Builder builder(getContext());
    Attribute zero = builder.getZeroAttr(getLhs().getType());
    results.push_back(getLhs());
    results.push_back(zero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
}
/// Fold (a * b) / b -> a, when the multiplication carries the given
/// no-overflow flags.
static Value foldDivMul(Value lhs, Value rhs,
                        arith::IntegerOverflowFlags ovfFlags) {
  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
    return {};
  if (mul.getLhs() == rhs)
    return mul.getRhs();
  if (mul.getRhs() == rhs)
    return mul.getLhs();
  return {};
}

OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  // (a * b) / b -> a, if no unsigned overflow (nuw).
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
    return val;

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (div0 || !b) {
          div0 = true;
          return a;
        }
        return a.udiv(b);
      });

  return div0 ? Attribute() : result;
}
OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  // (a * b) / b -> a, if no signed overflow (nsw).
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
    return val;

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        return a.sdiv_ov(b, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
/// Ceil division of two non-negative inputs: computes ((a - 1) / b) + 1,
/// tracking overflow.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
}
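// Worked instance of the formula above (illustrative): for a = 7, b = 2 the
// helper computes ((7 - 1) / 2) + 1 = 4, i.e. ceil(7 / 2).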
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        APInt quotient = a.udiv(b);
        if (!a.urem(b))
          return quotient;
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive: ceil(a, b) = ((a - 1) / b) + 1.
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative: ceil(a, b) = ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt div = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return div;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive: ceil(a, b) = -(-a / b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A is positive, b is negative: ceil(a, b) = -(a / -b).
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);
        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
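// Sign handling in the fold above, worked on small values (illustrative):
//   ceildivsi  7,  2 ->  4    // ((7 - 1) / 2) + 1
//   ceildivsi -7,  2 -> -3    // -(7 / 2)
//   ceildivsi  7, -2 -> -3    // -(7 / 2)
//   ceildivsi -7, -2 ->  4    // ((7 - 1) / 2) + 1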
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (b.isZero()) {
          overflowOrDiv = true;
          return a;
        }
        return a.sfloordiv_ov(b, overflowOrDiv);
      });

  return overflowOrDiv ? Attribute() : result;
}
OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (div0 || b.isZero()) {
          div0 = true;
          return a;
        }
        return a.urem(b);
      });

  return div0 ? Attribute() : result;
}

OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (div0 || b.isZero()) {
          div0 = true;
          return a;
        }
        return a.srem(b);
      });

  return div0 ? Attribute() : result;
}
/// Fold and(a, and(a, b)) to and(a, b).
static Value foldAndIofAndI(arith::AndIOp op) {
  for (bool reversePrev : {false, true}) {
    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
                    .getDefiningOp<arith::AndIOp>();
    if (!prev)
      continue;

    Value other = (reversePrev ? op.getLhs() : op.getRhs());
    if (other != prev.getLhs() && other != prev.getRhs())
      continue;

    return prev.getResult();
  }
  return {};
}

OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnes())
    return getLhs();
  /// and(x, not(x)) -> 0, where not(x) is xor(x, allOnes)
  if (auto xorOp = getRhs().getDefiningOp<XOrIOp>())
    if (xorOp.getLhs() == getLhs() &&
        matchPattern(xorOp.getRhs(), m_ConstantInt(&intValue)) &&
        intValue.isAllOnes())
      return Builder(getContext()).getZeroAttr(getType());
  /// and(not(x), x) -> 0
  if (auto xorOp = getLhs().getDefiningOp<XOrIOp>())
    if (xorOp.getLhs() == getRhs() &&
        matchPattern(xorOp.getRhs(), m_ConstantInt(&intValue)) &&
        intValue.isAllOnes())
      return Builder(getContext()).getZeroAttr(getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value result = foldAndIofAndI(*this))
    return result;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
}
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
    /// or(x, 0) -> x
    if (rhsVal.isZero())
      return getLhs();
    /// or(x, allOnes) -> allOnes
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, not(x)) -> allOnes, where not(x) is xor(x, allOnes)
  if (auto xorOp = getRhs().getDefiningOp<XOrIOp>())
    if (xorOp.getLhs() == getLhs() &&
        matchPattern(xorOp.getRhs(), m_ConstantInt(&intValue)) &&
        intValue.isAllOnes())
      return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(not(x), x) -> allOnes
  if (auto xorOp = getLhs().getDefiningOp<XOrIOp>())
    if (xorOp.getLhs() == getRhs() &&
        matchPattern(xorOp.getRhs(), m_ConstantInt(&intValue)) &&
        intValue.isAllOnes())
      return getLhs().getDefiningOp<XOrIOp>().getRhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(getType());
  /// xor(xor(x, a), a) -> x and xor(xor(a, x), a) -> x
  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();
  }
  /// xor(a, xor(x, a)) -> x and xor(a, xor(a, x)) -> x
  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}

void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}
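// Illustrative IR for the xor folds above (i32 operands assumed):
//   %0 = arith.xori %x, %a : i32
//   %1 = arith.xori %0, %a : i32   // folds to %x
//   %2 = arith.xori %x, %x : i32   // folds to the constant 0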
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
  /// negf(negf(x)) -> x
  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
    return op.getOperand();
  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                     [](const APFloat &a) { return -a; });
}
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x
  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });
}

OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x
  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}

OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  // maxnumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
}
OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
  // maxsi(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxsi(x, MAX_INT) -> MAX_INT
    if (intValue.isMaxSignedValue())
      return getRhs();
    // maxsi(x, MIN_INT) -> x
    if (intValue.isMinSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smax(a, b);
                                        });
}

OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
  // maxui(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxui(x, MAX_UINT) -> MAX_UINT
    if (intValue.isMaxValue())
      return getRhs();
    // maxui(x, 0) -> x
    if (intValue.isMinValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umax(a, b);
                                        });
}
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  // minimumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}

OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  // minnumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
}
OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
  // minsi(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minsi(x, MIN_INT) -> MIN_INT
    if (intValue.isMinSignedValue())
      return getRhs();
    // minsi(x, MAX_INT) -> x
    if (intValue.isMaxSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smin(a, b);
                                        });
}

OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
  // minui(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minui(x, 0) -> 0
    if (intValue.isMinValue())
      return getRhs();
    // minui(x, MAX_UINT) -> x
    if (intValue.isMaxValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umin(a, b);
                                        });
}
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
                                                   arith::FastMathFlags::nsz)) {
    // mulf(x, 0) -> 0 (only valid with nnan and nsz).
    if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
      return getRhs();
  }

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });
}

void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(context);
}
OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });
}

void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(context);
}
OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a, const APFloat &b) {
                                        APFloat result(a);
                                        // APFloat::mod() offers the remainder
                                        // behavior we want, i.e. the result
                                        // has the sign of the LHS operand.
                                        (void)result.mod(b);
                                        return result;
                                      });
}
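// The fold above keeps the sign of the lhs operand, e.g. (illustrative)
// arith.remf on the constants -5.5 and 2.0 folds to -1.5, while 5.5 and -2.0
// fold to 1.5.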
template <typename... Types>
using type_list = std::tuple<Types...> *;

/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}

/// Get allowed underlying types for vectors and tensors.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
                           type_list<ElementTypes...>());
}

/// Get allowed underlying types for vectors, tensors, and memrefs.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  return getUnderlyingType(type,
                           type_list<VectorType, TensorType, MemRefType>(),
                           type_list<ElementTypes...>());
}

/// Return false if both types are ranked tensors with mismatching encoding.
static bool hasSameEncoding(Type typeA, Type typeB) {
  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
  if (!rankedTensorA || !rankedTensorB)
    return true;
  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
}

static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (!hasSameEncoding(inputs.front(), outputs.front()))
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}
template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}

template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be shorter than operand type " << srcType;

  return success();
}
/// Validate a cast that changes the width of a type.
template <template <typename> class WidthComparator, typename... ElementTypes>
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
                                     srcType.getIntOrFloatBitWidth());
}
/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
  bool losesInfo = false;
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
  if (losesInfo || status != APFloat::opOK)
    return failure();

  return sourceValue;
}
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  // extui(extui(x)) -> extui(x)
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(bitWidth);
      });
}

LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  // extsi(extsi(x)) -> extsi(x)
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(bitWidth);
      });
}

void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}

LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
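// Illustrative difference between the two extension folds above, for an i8
// constant with bit pattern 0xFF:
//   arith.extui ... : i8 to i32  folds to 255 : i32 (zero extension)
//   arith.extsi ... : i8 to i32  folds to  -1 : i32 (sign extension)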
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  // extf(truncf(x)) -> x, when both casts carry the `contract` flag.
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}

LogicalResult arith::ScalingExtFOp::verify() {
  return verifyExtOp<FloatType>(*this);
}
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a) and trunci(sexti(a)) -> trunci(a)
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }
    // trunci(zexti(a)) -> a and trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}

void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns
      .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
          context);
}

LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
    Value src = extOp.getIn();
    auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
    auto intermediateType =
        cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
    // Check whether the source type is representable in the intermediate type.
    if (llvm::APFloatBase::isRepresentableBy(
            srcType.getFloatSemantics(),
            intermediateType.getFloatSemantics())) {
      // truncf(extf(a)) -> truncf(a)
      if (srcType.getWidth() > resElemType.getWidth()) {
        setOperand(src);
        return getResult();
      }
      // truncf(extf(a)) -> a
      if (srcType == resElemType)
        return src;
    }
  }

  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
            convertArithRoundingModeToLLVMIR(roundingMode);
        FailureOr<APFloat> result =
            convertFloatValue(a, targetSemantics, llvmRoundingMode);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
}

LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}

bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
                                               TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}

LogicalResult arith::ScalingTruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}

void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}

template <typename From, typename To>
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<From>(inputs.front());
  auto dstType = getTypeIfLike<To>(outputs.front());

  return srcType && dstType;
}
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}

OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/true);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}

OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/false);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
         (srcType.isSignlessInteger() && dstType.isIndex());
}

bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.sextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.zextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
}
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
  auto dstType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}

OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types are unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast poison.
  if (llvm::isa<ub::PoisonAttr>(operand))
    return operand;

  /// Bitcast integer or float to integer or float.
  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}

void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
static std::optional<int64_t> getIntegerWidth(Type t) {
  if (auto intType = dyn_cast<IntegerType>(t)) {
    return intType.getWidth();
  }
  if (auto vectorIntType = dyn_cast<VectorType>(t)) {
    return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
  }
  return std::nullopt;
}
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), val);
  }

  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
      // extsi(%x : i1 -> iN) != 0  ->  %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
      // extui(%x : i1 -> iN) != 0  ->  %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }

    // arith.cmpi ne, %val, %zero : i1 -> %val
    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
        getPredicate() == arith::CmpIPredicate::ne)
      return getLhs();
  }

  if (matchPattern(adaptor.getRhs(), m_One())) {
    // arith.cmpi eq, %val, %one : i1 -> %val
    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
        getPredicate() == arith::CmpIPredicate::eq)
      return getLhs();
  }

  // Move constant to the right side.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // Constants have been moved to the right side, so if the lhs is constant,
  // the rhs is guaranteed to be a constant as well.
  if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), getI1SameShape(lhs.getType()),
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1, static_cast<int64_t>(
                              applyCmpPredicate(pred, lhs, rhs)));
        });
  }

  return {};
}

void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}
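// Example of the "move constant to the right" step above (illustrative):
//   %r = arith.cmpi slt, %c5, %x : i32
// is rewritten in place to
//   %r = arith.cmpi sgt, %x, %c5 : i32
// so later folds only have to inspect a constant rhs.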
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
  auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
  auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());

  // If one operand is NaN, making them both NaN does not change the result.
  if (lhs && lhs.getValue().isNaN())
    rhs = lhs;
  if (rhs && rhs.getValue().isNaN())
    lhs = rhs;

  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                               bool isUnsigned) {
  using namespace arith;
  switch (pred) {
  case CmpFPredicate::UEQ:
  case CmpFPredicate::OEQ:
    return CmpIPredicate::eq;
  case CmpFPredicate::UGT:
  case CmpFPredicate::OGT:
    return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
  case CmpFPredicate::UGE:
  case CmpFPredicate::OGE:
    return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
  case CmpFPredicate::ULT:
  case CmpFPredicate::OLT:
    return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
  case CmpFPredicate::ULE:
  case CmpFPredicate::OLE:
    return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
  case CmpFPredicate::UNE:
  case CmpFPredicate::ONE:
    return CmpIPredicate::ne;
  default:
    llvm_unreachable("Unexpected predicate!");
  }
}
namespace {
/// Rewrite a cmpf of an integer-to-float conversion against a float constant
/// into an equivalent cmpi, when no precision is lost.
struct CmpFIntToFPConst : public OpRewritePattern<CmpFOp> {
  using OpRewritePattern<CmpFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a NaN.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa. We don't want to hack on conversions
    // that might lose information from the integer, e.g. "i64 -> float".
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;

    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check that the input is converted from an integer type small enough to
    // preserve all bits.
    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign.
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    if ((int)intWidth > mantissaWidth) {
      // The conversion would lose accuracy. Check if the loss can impact the
      // comparison.
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // If rhs is zero or NaN, the exponent is negative and the first
        // condition is trivially false.
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // Conversion could affect the comparison.
          return failure();
        }
      }
    }
    // Convert to an equivalent cmpi predicate.
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      // An int-to-fp conversion never produces a NaN, so `ord` always holds.
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
      return success();
    case CmpFPredicate::UNO:
      // Likewise, `uno` never holds.
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
      return success();
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    if (!isUnsigned) {
      // If the rhs value is > SignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
        return success();
      }
    } else {
      // If the rhs value is > UnsignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
        return success();
      }
    }

    if (!isUnsigned) {
      // See if the rhs value is < SignedMin.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
        return success();
      }
    } else {
      // See if the rhs value is < UnsignedMin.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
        return success();
      }
    }
    // The FP constant fits in the integer range, but it may still be
    // fractional. Check by converting to an integer and back.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // The destination type can't represent the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // The constant is fractional: adjust the predicate (and sometimes the
        // value) of the integer comparison.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4  --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4  --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4 --> int <= 4; (float)int <= -4.4 --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4 --> int <= 4; (float)int <= -4.4 --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4 --> false; (float)int < 4.4 --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4 --> int < -4; (float)int < 4.4 --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4 --> int > 4; (float)int > -4.4 --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4 --> int > 4; (float)int > -4.4 --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4 --> true; (float)int >= 4.4 --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 1, /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4 --> int >= -4; (float)int >= 4.4 --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
                           rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};
} // namespace
void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
}
// (Fragment from the SelectToExtUI canonicalization pattern: a select of the
// i1 constants 0 and 1 is rewritten into a zero-extension of the, possibly
// negated, condition.)
    // Cannot extui i1 to i1, or i1 to f32.
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
      return failure();

    // select %x, c0, c1 => extui (xor %x, true)
    rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
        op, op.getType(),
        arith::XOrIOp::create(
            rewriter, op.getLoc(), op.getCondition(),
            arith::ConstantIntOp::create(rewriter, op.getLoc(),
                                         op.getCondition().getType(), 1)));

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(adaptor.getCondition(), m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(adaptor.getCondition(), m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
    return falseVal;

  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1
      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0
      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over a non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    if (auto lhs =
            dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
      if (auto rhs =
              dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}
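// Illustrative IR for the select folds above (i32 operands assumed):
//   %0 = arith.select %cond, %x, %x : i32   // folds to %x
//   %1 = arith.cmpi ne, %a, %b : i32
//   %2 = arith.select %1, %a, %b : i32      // folds to %a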
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked select or set
  // the original i1 type.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}

void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";
  p << getType();
}
LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the condition can be a mask of
  // the same shape.
  Type resultType = getType();
  if (!llvm::isa<TensorType, VectorType>(resultType))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
  // shli(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting by at least the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.shl(b);
      });

  return bounded ? result : Attribute();
}

OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
  // shrui(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting by at least the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.lshr(b);
      });

  return bounded ? result : Attribute();
}

OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
  // shrsi(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting by at least the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.ashr(b);
      });

  return bounded ? result : Attribute();
}
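// The shift folds above refuse out-of-range shift amounts, e.g. (illustrative,
// i8 operands): shrsi -16, 2 folds to -4 and shli 1, 3 folds to 8, but a
// shift by 8 or more is left unfolded because `bounded` stays false.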
/// Returns the identity value attribute associated with an AtomicRMWKind op.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc,
                                            bool useOnlyFiniteValue) {
  switch (kind) {
  case AtomicRMWKind::maximumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/true)
                           : APFloat::getInf(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::maxnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
  case AtomicRMWKind::xori:
    return builder.getZeroAttr(resultType);
  case AtomicRMWKind::andi:
    return builder.getIntegerAttr(
        resultType,
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::maxs:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minimumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/false)
                           : APFloat::getInf(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::minnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::mins:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minu:
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
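// A few identities produced by the switch above (illustrative):
//   addf/addi -> 0,  muli -> 1,  andi -> all-ones,
//   maxsi -> smallest signed value,  maximumf -> -inf
//   (or the most negative finite value when only finite values are allowed).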
/// Return the identity numeric value associated to the given op.
std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
  std::optional<AtomicRMWKind> maybeKind =
      llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
          // Floating-point operations.
          .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
          .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
          // Integer operations.
          .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
          .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
          .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
          .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
          .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
          .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
          .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
          .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
          .Default(std::nullopt);
  if (!maybeKind)
    return std::nullopt;

  bool useOnlyFiniteValue = false;
  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
  if (fmfOpInterface) {
    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
    useOnlyFiniteValue =
        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
  }

  // Builder only used as a helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();

  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
                              useOnlyFiniteValue);
}

/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc,
                                    bool useOnlyFiniteValue) {
  auto attr =
      getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
  return arith::ConstantOp::create(builder, loc, attr);
}
/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                  Location loc, Value lhs, Value rhs) {
  switch (op) {
  case AtomicRMWKind::addf:
    return arith::AddFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return arith::AddIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return arith::MulFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return arith::MulIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maximumf:
    return arith::MaximumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::minimumf:
    return arith::MinimumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maxnumf:
    return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::minnumf:
    return arith::MinNumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maxs:
    return arith::MaxSIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::mins:
    return arith::MinSIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maxu:
    return arith::MaxUIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::minu:
    return arith::MinUIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::ori:
    return arith::OrIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::andi:
    return arith::AndIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::xori:
    return arith::XOrIOp::create(builder, loc, lhs, rhs);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"

#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.