23 #include "llvm/ADT/APInt.h"
24 #include "llvm/ADT/APSInt.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/ADT/SmallVector.h"
39 llvm::cast<IntegerAttr>(lhs).getInt() +
40 llvm::cast<IntegerAttr>(rhs).getInt());
46 llvm::cast<IntegerAttr>(lhs).getInt() -
47 llvm::cast<IntegerAttr>(rhs).getInt());
53 case arith::CmpIPredicate::eq:
54 return arith::CmpIPredicate::ne;
55 case arith::CmpIPredicate::ne:
56 return arith::CmpIPredicate::eq;
57 case arith::CmpIPredicate::slt:
58 return arith::CmpIPredicate::sge;
59 case arith::CmpIPredicate::sle:
60 return arith::CmpIPredicate::sgt;
61 case arith::CmpIPredicate::sgt:
62 return arith::CmpIPredicate::sle;
63 case arith::CmpIPredicate::sge:
64 return arith::CmpIPredicate::slt;
65 case arith::CmpIPredicate::ult:
66 return arith::CmpIPredicate::uge;
67 case arith::CmpIPredicate::ule:
68 return arith::CmpIPredicate::ugt;
69 case arith::CmpIPredicate::ugt:
70 return arith::CmpIPredicate::ule;
71 case arith::CmpIPredicate::uge:
72 return arith::CmpIPredicate::ult;
74 llvm_unreachable(
"unknown cmpi predicate kind");
95 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
96 return intAttr.getValue();
98 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
99 if (llvm::isa<IntegerType>(splatAttr.getElementType()))
100 return splatAttr.getSplatValue<APInt>();
110 #include "ArithCanonicalization.inc"
120 if (
auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
122 if (llvm::isa<UnrankedTensorType>(type))
124 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
126 vectorType.getNumScalableDims());
134 void arith::ConstantOp::getAsmResultNames(
136 auto type = getType();
137 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
138 auto intType = llvm::dyn_cast<IntegerType>(type);
141 if (intType && intType.getWidth() == 1)
142 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
146 llvm::raw_svector_ostream specialName(specialNameBuffer);
147 specialName <<
'c' << intCst.getValue();
149 specialName <<
'_' << type;
150 setNameFn(getResult(), specialName.str());
152 setNameFn(getResult(),
"cst");
159 auto type = getType();
161 if (getValue().getType() != type) {
162 return emitOpError() <<
"value type " << getValue().getType()
163 <<
" must match return type: " << type;
166 if (llvm::isa<IntegerType>(type) &&
167 !llvm::cast<IntegerType>(type).isSignless())
168 return emitOpError(
"integer return type must be signless");
170 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
172 "value must be an integer, float, or elements attribute");
177 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
179 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
180 if (!typedAttr || typedAttr.getType() != type)
183 if (llvm::isa<IntegerType>(type) &&
184 !llvm::cast<IntegerType>(type).isSignless())
187 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
192 if (isBuildableWith(value, type))
193 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
197 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
200 int64_t value,
unsigned width) {
202 arith::ConstantOp::build(builder, result, type,
207 int64_t value,
Type type) {
209 "ConstantIntOp can only have signless integer type values");
210 arith::ConstantOp::build(builder, result, type,
215 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
216 return constOp.getType().isSignlessInteger();
222 arith::ConstantOp::build(builder, result, type,
227 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
228 return llvm::isa<FloatType>(constOp.getType());
234 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
239 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
240 return constOp.getType().isIndex();
254 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
255 if (getRhs() == sub.getRhs())
259 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
260 if (getLhs() == sub.getRhs())
263 return constFoldBinaryOp<IntegerAttr>(
264 adaptor.getOperands(),
265 [](APInt a,
const APInt &b) { return std::move(a) + b; });
270 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
271 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
278 std::optional<SmallVector<int64_t, 4>>
279 arith::AddUIExtendedOp::getShapeForUnroll() {
280 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
281 return llvm::to_vector<4>(vt.getShape());
288 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
292 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
294 Type overflowTy = getOverflow().getType();
300 results.push_back(getLhs());
301 results.push_back(falseValue);
309 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
310 adaptor.getOperands(),
311 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
312 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
313 ArrayRef({sumAttr, adaptor.getLhs()}),
319 results.push_back(sumAttr);
320 results.push_back(overflowAttr);
327 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
329 patterns.
add<AddUIExtendedToAddI>(context);
338 if (getOperand(0) == getOperand(1))
344 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
346 if (getRhs() == add.getRhs())
349 if (getRhs() == add.getLhs())
353 return constFoldBinaryOp<IntegerAttr>(
354 adaptor.getOperands(),
355 [](APInt a,
const APInt &b) { return std::move(a) - b; });
360 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
361 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
362 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
375 return getOperand(0);
379 return constFoldBinaryOp<IntegerAttr>(
380 adaptor.getOperands(),
381 [](
const APInt &a,
const APInt &b) { return a * b; });
388 std::optional<SmallVector<int64_t, 4>>
389 arith::MulSIExtendedOp::getShapeForUnroll() {
390 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
391 return llvm::to_vector<4>(vt.getShape());
396 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
401 results.push_back(zero);
402 results.push_back(zero);
407 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
408 adaptor.getOperands(),
409 [](
const APInt &a,
const APInt &b) { return a * b; })) {
411 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
412 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
413 unsigned bitWidth = a.getBitWidth();
414 APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
415 return fullProduct.extractBits(bitWidth, bitWidth);
417 assert(highAttr &&
"Unexpected constant-folding failure");
419 results.push_back(lowAttr);
420 results.push_back(highAttr);
427 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
429 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
436 std::optional<SmallVector<int64_t, 4>>
437 arith::MulUIExtendedOp::getShapeForUnroll() {
438 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
439 return llvm::to_vector<4>(vt.getShape());
444 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
449 results.push_back(zero);
450 results.push_back(zero);
458 results.push_back(getLhs());
459 results.push_back(zero);
464 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
465 adaptor.getOperands(),
466 [](
const APInt &a,
const APInt &b) { return a * b; })) {
468 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
469 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
470 unsigned bitWidth = a.getBitWidth();
471 APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
472 return fullProduct.extractBits(bitWidth, bitWidth);
474 assert(highAttr &&
"Unexpected constant-folding failure");
476 results.push_back(lowAttr);
477 results.push_back(highAttr);
484 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
486 patterns.
add<MulUIExtendedToMulI>(context);
493 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
500 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
501 [&](APInt a,
const APInt &b) {
522 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
528 bool overflowOrDiv0 =
false;
529 auto result = constFoldBinaryOp<IntegerAttr>(
530 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
531 if (overflowOrDiv0 || !b) {
532 overflowOrDiv0 = true;
535 return a.sdiv_ov(b, overflowOrDiv0);
538 return overflowOrDiv0 ?
Attribute() : result;
542 bool mayHaveUB =
true;
548 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
560 APInt one(a.getBitWidth(), 1,
true);
561 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
562 return val.sadd_ov(one, overflow);
569 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
574 bool overflowOrDiv0 =
false;
575 auto result = constFoldBinaryOp<IntegerAttr>(
576 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
577 if (overflowOrDiv0 || !b) {
578 overflowOrDiv0 = true;
581 APInt quotient = a.udiv(b);
584 APInt one(a.getBitWidth(), 1,
true);
585 return quotient.uadd_ov(one, overflowOrDiv0);
588 return overflowOrDiv0 ?
Attribute() : result;
601 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
607 bool overflowOrDiv0 =
false;
608 auto result = constFoldBinaryOp<IntegerAttr>(
609 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
610 if (overflowOrDiv0 || !b) {
611 overflowOrDiv0 = true;
617 unsigned bits = a.getBitWidth();
619 bool aGtZero = a.sgt(zero);
620 bool bGtZero = b.sgt(zero);
621 if (aGtZero && bGtZero) {
625 if (!aGtZero && !bGtZero) {
627 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
628 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
631 if (!aGtZero && bGtZero) {
633 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
634 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
635 return zero.ssub_ov(div, overflowOrDiv0);
638 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
639 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
640 return zero.ssub_ov(div, overflowOrDiv0);
643 return overflowOrDiv0 ?
Attribute() : result;
647 bool mayHaveUB =
true;
653 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
662 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
668 bool overflowOrDiv0 =
false;
669 auto result = constFoldBinaryOp<IntegerAttr>(
670 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
671 if (overflowOrDiv0 || !b) {
672 overflowOrDiv0 = true;
678 unsigned bits = a.getBitWidth();
680 bool aGtZero = a.sgt(zero);
681 bool bGtZero = b.sgt(zero);
682 if (aGtZero && bGtZero) {
684 return a.sdiv_ov(b, overflowOrDiv0);
686 if (!aGtZero && !bGtZero) {
688 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
689 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
690 return posA.sdiv_ov(posB, overflowOrDiv0);
692 if (!aGtZero && bGtZero) {
694 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
696 return zero.ssub_ov(
ceil, overflowOrDiv0);
699 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
701 return zero.ssub_ov(
ceil, overflowOrDiv0);
704 return overflowOrDiv0 ?
Attribute() : result;
711 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
718 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
719 [&](APInt a,
const APInt &b) {
720 if (div0 || b.isZero()) {
734 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
741 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
742 [&](APInt a,
const APInt &b) {
743 if (div0 || b.isZero()) {
759 for (
bool reversePrev : {
false,
true}) {
760 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
761 .getDefiningOp<arith::AndIOp>();
765 Value other = (reversePrev ? op.getLhs() : op.getRhs());
766 if (other != prev.getLhs() && other != prev.getRhs())
769 return prev.getResult();
785 intValue.isAllOnes())
790 intValue.isAllOnes())
797 return constFoldBinaryOp<IntegerAttr>(
798 adaptor.getOperands(),
799 [](APInt a,
const APInt &b) { return std::move(a) & b; });
811 if (
auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
812 if (rhsAttr.getValue().isAllOnes())
819 intValue.isAllOnes())
820 return getRhs().getDefiningOp<XOrIOp>().getRhs();
824 intValue.isAllOnes())
825 return getLhs().getDefiningOp<XOrIOp>().getRhs();
827 return constFoldBinaryOp<IntegerAttr>(
828 adaptor.getOperands(),
829 [](APInt a,
const APInt &b) { return std::move(a) | b; });
841 if (getLhs() == getRhs())
845 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
846 if (prev.getRhs() == getRhs())
847 return prev.getLhs();
848 if (prev.getLhs() == getRhs())
849 return prev.getRhs();
853 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
854 if (prev.getRhs() == getLhs())
855 return prev.getLhs();
856 if (prev.getLhs() == getLhs())
857 return prev.getRhs();
860 return constFoldBinaryOp<IntegerAttr>(
861 adaptor.getOperands(),
862 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
867 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
876 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
878 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
879 [](
const APFloat &a) { return -a; });
891 return constFoldBinaryOp<FloatAttr>(
892 adaptor.getOperands(),
893 [](
const APFloat &a,
const APFloat &b) { return a + b; });
905 return constFoldBinaryOp<FloatAttr>(
906 adaptor.getOperands(),
907 [](
const APFloat &a,
const APFloat &b) { return a - b; });
916 if (getLhs() == getRhs())
923 return constFoldBinaryOp<FloatAttr>(
924 adaptor.getOperands(),
925 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
934 if (getLhs() == getRhs())
940 intValue.isMaxSignedValue())
945 intValue.isMinSignedValue())
948 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
949 [](
const APInt &a,
const APInt &b) {
950 return llvm::APIntOps::smax(a, b);
960 if (getLhs() == getRhs())
972 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
973 [](
const APInt &a,
const APInt &b) {
974 return llvm::APIntOps::umax(a, b);
984 if (getLhs() == getRhs())
991 return constFoldBinaryOp<FloatAttr>(
992 adaptor.getOperands(),
993 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1002 if (getLhs() == getRhs())
1008 intValue.isMinSignedValue())
1013 intValue.isMaxSignedValue())
1016 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1017 [](
const APInt &a,
const APInt &b) {
1018 return llvm::APIntOps::smin(a, b);
1028 if (getLhs() == getRhs())
1040 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1041 [](
const APInt &a,
const APInt &b) {
1042 return llvm::APIntOps::umin(a, b);
1050 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1055 return constFoldBinaryOp<FloatAttr>(
1056 adaptor.getOperands(),
1057 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1062 patterns.
add<MulFOfNegF>(context);
1069 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1074 return constFoldBinaryOp<FloatAttr>(
1075 adaptor.getOperands(),
1076 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1081 patterns.
add<DivFOfNegF>(context);
1088 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1089 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1090 [](
const APFloat &a,
const APFloat &b) {
1092 (void)result.remainder(b);
1101 template <
typename... Types>
1107 template <
typename... ShapedTypes,
typename... ElementTypes>
1110 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1114 if (!llvm::isa<ElementTypes...>(underlyingType))
1117 return underlyingType;
1121 template <
typename... ElementTypes>
1128 template <
typename... ElementTypes>
1136 return inputs.size() == 1 && outputs.size() == 1 &&
1145 template <
typename ValType,
typename Op>
1150 if (llvm::cast<ValType>(srcType).getWidth() >=
1151 llvm::cast<ValType>(dstType).getWidth())
1153 << dstType <<
" must be wider than operand type " << srcType;
1159 template <
typename ValType,
typename Op>
1164 if (llvm::cast<ValType>(srcType).getWidth() <=
1165 llvm::cast<ValType>(dstType).getWidth())
1167 << dstType <<
" must be shorter than operand type " << srcType;
1173 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1178 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1179 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1180 if (!srcType || !dstType)
1183 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1184 srcType.getIntOrFloatBitWidth());
1191 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1192 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1193 getInMutable().assign(lhs.getIn());
1198 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1199 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1200 adaptor.getOperands(), getType(),
1201 [bitWidth](
const APInt &a,
bool &castStatus) {
1202 return a.zext(bitWidth);
1207 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1211 return verifyExtOp<IntegerType>(*
this);
1218 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1219 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1220 getInMutable().assign(lhs.getIn());
1225 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1226 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1227 adaptor.getOperands(), getType(),
1228 [bitWidth](
const APInt &a,
bool &castStatus) {
1229 return a.sext(bitWidth);
1234 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1239 patterns.
add<ExtSIOfExtUI>(context);
1243 return verifyExtOp<IntegerType>(*
this);
1251 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1252 auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
1257 return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
1261 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1270 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1271 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1278 if (llvm::cast<IntegerType>(srcType).getWidth() >
1279 llvm::cast<IntegerType>(dstType).getWidth()) {
1289 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1290 setOperand(getOperand().getDefiningOp()->getOperand(0));
1295 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1296 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1297 adaptor.getOperands(), getType(),
1298 [bitWidth](
const APInt &a,
bool &castStatus) {
1299 return a.trunc(bitWidth);
1304 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1309 patterns.
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1310 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1315 return verifyTruncateOp<IntegerType>(*
this);
1324 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1325 auto constOperand = adaptor.getIn();
1326 if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
1330 double sourceValue =
1331 llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
1335 if (sourceValue == targetAttr.getValue().convertToDouble())
1342 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1346 return verifyTruncateOp<FloatType>(*
this);
1355 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1364 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1371 template <
typename From,
typename To>
1376 auto srcType = getTypeIfLike<From>(inputs.front());
1377 auto dstType = getTypeIfLike<To>(outputs.back());
1379 return srcType && dstType;
1387 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1390 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1392 return constFoldCastOp<IntegerAttr, FloatAttr>(
1393 adaptor.getOperands(), getType(),
1394 [&resEleType](
const APInt &a,
bool &castStatus) {
1395 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1398 apf.convertFromAPInt(a,
false,
1399 APFloat::rmNearestTiesToEven);
1409 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1412 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1414 return constFoldCastOp<IntegerAttr, FloatAttr>(
1415 adaptor.getOperands(), getType(),
1416 [&resEleType](
const APInt &a,
bool &castStatus) {
1417 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1420 apf.convertFromAPInt(a,
true,
1421 APFloat::rmNearestTiesToEven);
1430 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1433 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1435 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1436 return constFoldCastOp<FloatAttr, IntegerAttr>(
1437 adaptor.getOperands(), getType(),
1438 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1440 APSInt api(bitWidth,
true);
1441 castStatus = APFloat::opInvalidOp !=
1442 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1452 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1455 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1457 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1458 return constFoldCastOp<FloatAttr, IntegerAttr>(
1459 adaptor.getOperands(), getType(),
1460 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1462 APSInt api(bitWidth,
false);
1463 castStatus = APFloat::opInvalidOp !=
1464 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1477 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1478 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1479 if (!srcType || !dstType)
1486 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1491 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1493 unsigned resultBitwidth = 64;
1495 resultBitwidth = intTy.getWidth();
1497 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1498 adaptor.getOperands(), getType(),
1499 [resultBitwidth](
const APInt &a,
bool & ) {
1500 return a.sextOrTrunc(resultBitwidth);
1504 void arith::IndexCastOp::getCanonicalizationPatterns(
1506 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1513 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1518 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1520 unsigned resultBitwidth = 64;
1522 resultBitwidth = intTy.getWidth();
1524 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1525 adaptor.getOperands(), getType(),
1526 [resultBitwidth](
const APInt &a,
bool & ) {
1527 return a.zextOrTrunc(resultBitwidth);
1531 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1533 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1545 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1547 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1548 if (!srcType || !dstType)
1554 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1555 auto resType = getType();
1556 auto operand = adaptor.getIn();
1561 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1562 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1564 if (llvm::isa<ShapedType>(resType))
1568 APInt bits = llvm::isa<FloatAttr>(operand)
1569 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1570 : llvm::cast<IntegerAttr>(operand).getValue();
1572 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1574 APFloat(resFloatType.getFloatSemantics(), bits));
1580 patterns.
add<BitcastOfBitcast>(context);
1590 const APInt &lhs,
const APInt &rhs) {
1591 switch (predicate) {
1592 case arith::CmpIPredicate::eq:
1594 case arith::CmpIPredicate::ne:
1596 case arith::CmpIPredicate::slt:
1597 return lhs.slt(rhs);
1598 case arith::CmpIPredicate::sle:
1599 return lhs.sle(rhs);
1600 case arith::CmpIPredicate::sgt:
1601 return lhs.sgt(rhs);
1602 case arith::CmpIPredicate::sge:
1603 return lhs.sge(rhs);
1604 case arith::CmpIPredicate::ult:
1605 return lhs.ult(rhs);
1606 case arith::CmpIPredicate::ule:
1607 return lhs.ule(rhs);
1608 case arith::CmpIPredicate::ugt:
1609 return lhs.ugt(rhs);
1610 case arith::CmpIPredicate::uge:
1611 return lhs.uge(rhs);
1613 llvm_unreachable(
"unknown cmpi predicate kind");
1618 switch (predicate) {
1619 case arith::CmpIPredicate::eq:
1620 case arith::CmpIPredicate::sle:
1621 case arith::CmpIPredicate::sge:
1622 case arith::CmpIPredicate::ule:
1623 case arith::CmpIPredicate::uge:
1625 case arith::CmpIPredicate::ne:
1626 case arith::CmpIPredicate::slt:
1627 case arith::CmpIPredicate::sgt:
1628 case arith::CmpIPredicate::ult:
1629 case arith::CmpIPredicate::ugt:
1632 llvm_unreachable(
"unknown cmpi predicate kind");
1637 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
1644 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1645 return intType.getWidth();
1647 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1648 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1650 return std::nullopt;
1653 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1655 if (getLhs() == getRhs()) {
1661 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1663 std::optional<int64_t> integerWidth =
1665 if (integerWidth && integerWidth.value() == 1 &&
1666 getPredicate() == arith::CmpIPredicate::ne)
1667 return extOp.getOperand();
1669 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1671 std::optional<int64_t> integerWidth =
1673 if (integerWidth && integerWidth.value() == 1 &&
1674 getPredicate() == arith::CmpIPredicate::ne)
1675 return extOp.getOperand();
1680 if (adaptor.getLhs() && !adaptor.getRhs()) {
1682 using Pred = CmpIPredicate;
1683 const std::pair<Pred, Pred> invPreds[] = {
1684 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1685 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1686 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1687 {Pred::ne, Pred::ne},
1689 Pred origPred = getPredicate();
1690 for (
auto pred : invPreds) {
1691 if (origPred == pred.first) {
1692 setPredicate(pred.second);
1693 Value lhs = getLhs();
1694 Value rhs = getRhs();
1695 getLhsMutable().assign(rhs);
1696 getRhsMutable().assign(lhs);
1700 llvm_unreachable(
"unknown cmpi predicate kind");
1705 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1706 return constFoldBinaryOp<IntegerAttr>(
1708 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1719 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
1729 const APFloat &lhs,
const APFloat &rhs) {
1730 auto cmpResult = lhs.compare(rhs);
1731 switch (predicate) {
1732 case arith::CmpFPredicate::AlwaysFalse:
1734 case arith::CmpFPredicate::OEQ:
1735 return cmpResult == APFloat::cmpEqual;
1736 case arith::CmpFPredicate::OGT:
1737 return cmpResult == APFloat::cmpGreaterThan;
1738 case arith::CmpFPredicate::OGE:
1739 return cmpResult == APFloat::cmpGreaterThan ||
1740 cmpResult == APFloat::cmpEqual;
1741 case arith::CmpFPredicate::OLT:
1742 return cmpResult == APFloat::cmpLessThan;
1743 case arith::CmpFPredicate::OLE:
1744 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1745 case arith::CmpFPredicate::ONE:
1746 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1747 case arith::CmpFPredicate::ORD:
1748 return cmpResult != APFloat::cmpUnordered;
1749 case arith::CmpFPredicate::UEQ:
1750 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1751 case arith::CmpFPredicate::UGT:
1752 return cmpResult == APFloat::cmpUnordered ||
1753 cmpResult == APFloat::cmpGreaterThan;
1754 case arith::CmpFPredicate::UGE:
1755 return cmpResult == APFloat::cmpUnordered ||
1756 cmpResult == APFloat::cmpGreaterThan ||
1757 cmpResult == APFloat::cmpEqual;
1758 case arith::CmpFPredicate::ULT:
1759 return cmpResult == APFloat::cmpUnordered ||
1760 cmpResult == APFloat::cmpLessThan;
1761 case arith::CmpFPredicate::ULE:
1762 return cmpResult == APFloat::cmpUnordered ||
1763 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1764 case arith::CmpFPredicate::UNE:
1765 return cmpResult != APFloat::cmpEqual;
1766 case arith::CmpFPredicate::UNO:
1767 return cmpResult == APFloat::cmpUnordered;
1768 case arith::CmpFPredicate::AlwaysTrue:
1771 llvm_unreachable(
"unknown cmpf predicate kind");
1774 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1775 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1776 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1779 if (lhs && lhs.getValue().isNaN())
1781 if (rhs && rhs.getValue().isNaN())
1797 using namespace arith;
1799 case CmpFPredicate::UEQ:
1800 case CmpFPredicate::OEQ:
1801 return CmpIPredicate::eq;
1802 case CmpFPredicate::UGT:
1803 case CmpFPredicate::OGT:
1804 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1805 case CmpFPredicate::UGE:
1806 case CmpFPredicate::OGE:
1807 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1808 case CmpFPredicate::ULT:
1809 case CmpFPredicate::OLT:
1810 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1811 case CmpFPredicate::ULE:
1812 case CmpFPredicate::OLE:
1813 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1814 case CmpFPredicate::UNE:
1815 case CmpFPredicate::ONE:
1816 return CmpIPredicate::ne;
1818 llvm_unreachable(
"Unexpected predicate!");
1828 const APFloat &rhs = flt.getValue();
1836 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1838 if (mantissaWidth <= 0)
1844 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1846 intVal = si.getIn();
1847 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1849 intVal = ui.getIn();
1856 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
1857 auto intWidth = intTy.getWidth();
1860 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1865 if ((
int)intWidth > mantissaWidth) {
1867 int exponent = ilogb(rhs);
1868 if (exponent == APFloat::IEK_Inf) {
1869 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1870 if (maxExponent < (
int)valueBits) {
1877 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
1886 switch (op.getPredicate()) {
1887 case CmpFPredicate::ORD:
1892 case CmpFPredicate::UNO:
1905 APFloat signedMax(rhs.getSemantics());
1906 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
1907 APFloat::rmNearestTiesToEven);
1908 if (signedMax < rhs) {
1909 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1910 pred == CmpIPredicate::sle)
1921 APFloat unsignedMax(rhs.getSemantics());
1922 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
1923 APFloat::rmNearestTiesToEven);
1924 if (unsignedMax < rhs) {
1925 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1926 pred == CmpIPredicate::ule)
1938 APFloat signedMin(rhs.getSemantics());
1939 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
1940 APFloat::rmNearestTiesToEven);
1941 if (signedMin > rhs) {
1942 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1943 pred == CmpIPredicate::sge)
1953 APFloat unsignedMin(rhs.getSemantics());
1954 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
1955 APFloat::rmNearestTiesToEven);
1956 if (unsignedMin > rhs) {
1957 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1958 pred == CmpIPredicate::uge)
1973 APSInt rhsInt(intWidth, isUnsigned);
1974 if (APFloat::opInvalidOp ==
1975 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1981 if (!rhs.isZero()) {
1984 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1986 bool equal = apf == rhs;
1992 case CmpIPredicate::ne:
1996 case CmpIPredicate::eq:
2000 case CmpIPredicate::ule:
2003 if (rhs.isNegative()) {
2009 case CmpIPredicate::sle:
2012 if (rhs.isNegative())
2013 pred = CmpIPredicate::slt;
2015 case CmpIPredicate::ult:
2018 if (rhs.isNegative()) {
2023 pred = CmpIPredicate::ule;
2025 case CmpIPredicate::slt:
2028 if (!rhs.isNegative())
2029 pred = CmpIPredicate::sle;
2031 case CmpIPredicate::ugt:
2034 if (rhs.isNegative()) {
2040 case CmpIPredicate::sgt:
2043 if (rhs.isNegative())
2044 pred = CmpIPredicate::sge;
2046 case CmpIPredicate::uge:
2049 if (rhs.isNegative()) {
2054 pred = CmpIPredicate::ugt;
2056 case CmpIPredicate::sge:
2059 if (!rhs.isNegative())
2060 pred = CmpIPredicate::sgt;
2070 rewriter.
create<ConstantOp>(
2098 if (!op.getType().isInteger(1))
2101 Value falseConstant =
2102 rewriter.
create<arith::ConstantIntOp>(op.
getLoc(),
true, 1);
2103 Value notCondition = rewriter.
create<arith::XOrIOp>(
2104 op.
getLoc(), op.getCondition(), falseConstant);
2107 op.
getLoc(), op.getCondition(), op.getTrueValue());
2109 op.getFalseValue());
2122 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2138 rewriter.
create<arith::XOrIOp>(
2139 op.
getLoc(), op.getCondition(),
2140 rewriter.
create<arith::ConstantIntOp>(
2141 op.
getLoc(), 1, op.getCondition().getType())));
2154 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2155 Value trueVal = getTrueValue();
2156 Value falseVal = getFalseValue();
2157 if (trueVal == falseVal)
2160 Value condition = getCondition();
2175 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2176 auto pred = cmp.getPredicate();
2177 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2178 auto cmpLhs = cmp.getLhs();
2179 auto cmpRhs = cmp.getRhs();
2187 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2188 (cmpRhs == trueVal && cmpLhs == falseVal))
2189 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2196 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2198 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2200 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2202 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2203 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2205 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2207 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2210 for (
auto [condVal, lhsVal, rhsVal] :
2211 llvm::zip_equal(condVals, lhsVals, rhsVals))
2212 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2223 Type conditionType, resultType;
2232 conditionType = resultType;
2241 {conditionType, resultType, resultType},
2246 p <<
" " << getOperands();
2249 if (ShapedType condType =
2250 llvm::dyn_cast<ShapedType>(getCondition().getType()))
2251 p << condType <<
", ";
2256 Type conditionType = getCondition().getType();
2262 Type resultType = getType();
2263 if (!llvm::isa<TensorType, VectorType>(resultType))
2264 return emitOpError() <<
"expected condition to be a signless i1, but got "
2267 if (conditionType != shapedConditionType) {
2268 return emitOpError() <<
"expected condition type to have the same shape "
2269 "as the result type, expected "
2270 << shapedConditionType <<
", but got "
2279 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2284 bool bounded =
false;
2285 auto result = constFoldBinaryOp<IntegerAttr>(
2286 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2287 bounded = b.ule(b.getBitWidth());
2297 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2302 bool bounded =
false;
2303 auto result = constFoldBinaryOp<IntegerAttr>(
2304 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2305 bounded = b.ule(b.getBitWidth());
2315 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2320 bool bounded =
false;
2321 auto result = constFoldBinaryOp<IntegerAttr>(
2322 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2323 bounded = b.ule(b.getBitWidth());
2337 case AtomicRMWKind::maxf:
2340 APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
2342 case AtomicRMWKind::addf:
2343 case AtomicRMWKind::addi:
2344 case AtomicRMWKind::maxu:
2345 case AtomicRMWKind::ori:
2347 case AtomicRMWKind::andi:
2350 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2351 case AtomicRMWKind::maxs:
2353 resultType, APInt::getSignedMinValue(
2354 llvm::cast<IntegerType>(resultType).getWidth()));
2355 case AtomicRMWKind::minf:
2358 APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
2360 case AtomicRMWKind::mins:
2362 resultType, APInt::getSignedMaxValue(
2363 llvm::cast<IntegerType>(resultType).getWidth()));
2364 case AtomicRMWKind::minu:
2367 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2368 case AtomicRMWKind::muli:
2370 case AtomicRMWKind::mulf:
2384 return builder.
create<arith::ConstantOp>(loc, attr);
2392 case AtomicRMWKind::addf:
2393 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2394 case AtomicRMWKind::addi:
2395 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2396 case AtomicRMWKind::mulf:
2397 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2398 case AtomicRMWKind::muli:
2399 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2400 case AtomicRMWKind::maxf:
2401 return builder.
create<arith::MaxFOp>(loc, lhs, rhs);
2402 case AtomicRMWKind::minf:
2403 return builder.
create<arith::MinFOp>(loc, lhs, rhs);
2404 case AtomicRMWKind::maxs:
2405 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2406 case AtomicRMWKind::mins:
2407 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2408 case AtomicRMWKind::maxu:
2409 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2410 case AtomicRMWKind::minu:
2411 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2412 case AtomicRMWKind::ori:
2413 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2414 case AtomicRMWKind::andi:
2415 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2428 #define GET_OP_CLASSES
2429 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2435 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
std::tuple< Types... > * type_list
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
This class provides support for representing a failure result, or a valid value of type T.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
MPInt ceil(const Fraction &f)
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)