20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/MathExtras.h"
27 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
// Splits `value` into two halves of `newBitWidth` bits each and returns them
// as {low, high}: `low` holds bits [0, newBitWidth) and `high` holds bits
// [newBitWidth, 2 * newBitWidth).
// NOTE(review): APInt::extractBits requires the extracted range to lie within
// `value`, so callers presumably pass values at least 2 * newBitWidth bits
// wide -- confirm at call sites.
41 static std::pair<APInt, APInt>
getHalves(
const APInt &value,
42 unsigned newBitWidth) {
43 APInt low = value.extractBits(newBitWidth, 0);
44 APInt high = value.extractBits(newBitWidth, newBitWidth);
45 return {std::move(low), std::move(high)};
54 if (type.getShape().size() == 1)
55 return type.getElementType();
57 auto newShape = to_vector(type.getShape());
71 assert(lastOffset < shape.back() &&
"Offset out of bounds");
74 if (shape.size() == 1)
75 return vector::ExtractOp::create(rewriter, loc, input, lastOffset);
78 offsets.back() = lastOffset;
79 auto sizes = llvm::to_vector(shape);
83 return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets,
89 static std::pair<Value, Value>
// Fragment: removes the trailing unit (x1) dimension from the vector `input`
// by emitting a vector.shape_cast to the same element type with the last
// dimension dropped.
100 auto vecTy = dyn_cast<VectorType>(input.
getType());
// NOTE(review): the assert message below reads "at list"; it presumably
// should say "at least".
106 assert(shape.size() >= 2 &&
"Expected vector with at list two dims");
107 assert(shape.back() == 1 &&
"Expected the last vector dim to be x1");
109 auto newVecTy =
VectorType::get(shape.drop_back(), vecTy.getElementType());
110 return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input);
117 auto vecTy = dyn_cast<VectorType>(input.
getType());
122 auto newShape = llvm::to_vector(vecTy.getShape());
123 newShape.push_back(1);
125 return vector::ShapeCastOp::create(rewriter, loc, newTy, input);
133 int64_t lastOffset) {
135 assert(lastOffset < shape.back() &&
"Offset out of bounds");
138 if (isa<IntegerType>(source.
getType()))
139 return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset);
142 offsets.back() = lastOffset;
144 return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest,
155 Location loc, VectorType resultType,
159 assert(!resultShape.empty() &&
"Result expected to have dimensions");
160 assert(resultShape.back() ==
static_cast<int64_t
>(resultComponents.size()) &&
161 "Wrong number of result components");
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
181 Type oldType = op.getType();
182 auto newType = getTypeConverter()->convertType<VectorType>(oldType);
185 op, llvm::formatv(
"unsupported type: {0}", op.getType()));
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
190 if (
auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191 auto [low, high] =
getHalves(intAttr.getValue(), newBitWidth);
197 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
199 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
202 values.reserve(numSplatElems * 2);
203 for (int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
213 if (
auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
216 values.reserve(numElems * 2);
217 for (
const APInt &origVal : elemsAttr.getValues<APInt>()) {
218 auto [low, high] =
getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
229 "unhandled constant attribute");
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
244 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
247 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
251 auto [lhsElem0, lhsElem1] =
253 auto [rhsElem0, rhsElem1] =
257 arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
259 arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow());
261 Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1);
262 Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1);
276 template <
typename BinaryOp>
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
285 auto newTy = this->getTypeConverter()->template convertType<VectorType>(
289 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
291 auto [lhsElem0, lhsElem1] =
293 auto [rhsElem0, rhsElem1] =
296 Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0);
297 Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1);
311 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
335 getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
338 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
343 auto [lhsElem0, lhsElem1] =
345 auto [rhsElem0, rhsElem1] =
349 arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0);
351 arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1);
355 case arith::CmpIPredicate::eq: {
356 cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp);
359 case arith::CmpIPredicate::ne: {
360 cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp);
365 Value highEq = arith::CmpIOp::create(
366 rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
368 arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp);
373 assert(cmpResult &&
"Unhandled case");
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
390 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
393 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
395 auto [lhsElem0, lhsElem1] =
397 auto [rhsElem0, rhsElem1] =
404 arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
405 Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1);
406 Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0);
408 Value resLow = mulLowLow.getLow();
410 arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow);
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
431 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
434 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
443 loc, newResultComponentTy, newOperand);
444 Value operandZeroCst =
446 Value signBit = arith::CmpIOp::create(
447 rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
449 arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit);
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
469 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
472 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
480 loc, newResultComponentTy, newOperand);
492 template <
typename SourceOp, arith::CmpIPredicate CmpPred>
497 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
501 Type oldTy = op.getType();
502 auto newTy = dyn_cast_or_null<VectorType>(
503 this->getTypeConverter()->convertType(oldTy));
506 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
511 arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs());
522 static bool isIndexOrIndexVector(
Type type) {
523 if (isa<IndexType>(type))
526 if (
auto vectorTy = dyn_cast<VectorType>(type))
527 if (isa<IndexType>(vectorTy.getElementType()))
533 template <
typename CastOp>
538 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
540 Type resultType = op.getType();
541 if (!isIndexOrIndexVector(resultType))
545 Type inType = op.getIn().getType();
547 this->getTypeConverter()->template convertType<VectorType>(inType);
550 loc, llvm::formatv(
"unsupported type: {0}", inType));
560 template <
typename CastOp,
typename ExtensionOp>
565 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
567 Type inType = op.getIn().getType();
568 if (!isIndexOrIndexVector(inType))
572 auto *typeConverter =
573 this->
template getTypeConverter<arith::WideIntEmulationConverter>();
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType<VectorType>(resultType);
579 loc, llvm::formatv(
"unsupported type: {0}", resultType));
583 rewriter.
getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (
auto vecTy = dyn_cast<VectorType>(resultType))
589 Value underlyingVal =
590 CastOp::create(rewriter, loc, narrowTy, adaptor.getIn());
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
607 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
610 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
612 auto [trueElem0, trueElem1] =
614 auto [falseElem0, falseElem1] =
619 arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0);
621 arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1);
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
645 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
651 auto [lhsElem0, lhsElem1] =
683 Value illegalElemShift = arith::CmpIOp::create(
684 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
687 arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0);
688 Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
689 zeroCst, shiftedElem0);
691 Value cappedShiftAmount = arith::SelectOp::create(
692 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
696 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount);
697 Value overshotShiftAmount =
698 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
700 arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount);
703 arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0);
704 Value resElem1High = arith::SelectOp::create(
705 rewriter, loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
707 shiftedLeft, shiftedRight);
709 arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High);
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
734 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
740 auto [lhsElem0, lhsElem1] =
772 Value illegalElemShift = arith::CmpIOp::create(
773 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
776 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0);
777 Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
778 zeroCst, shiftedElem0);
780 arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0);
781 Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
782 zeroCst, shiftedElem1);
784 Value cappedShiftAmount = arith::SelectOp::create(
785 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
789 arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount);
790 Value overshotShiftAmount =
791 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
793 arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount);
795 Value resElem0High = arith::SelectOp::create(
796 rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft);
798 arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High);
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
823 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
835 Value signBit = arith::CmpIOp::create(
836 rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
842 Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit);
845 Value numNonSignExtBits =
846 arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0);
849 arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits);
851 arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits);
855 arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs());
856 Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits);
860 Value isNoop = arith::CmpIOp::create(
861 rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero);
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
881 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
884 loc, llvm::formatv(
"unsupported type: {}", op.getType()));
888 auto [lhsElem0, lhsElem1] =
890 auto [rhsElem0, rhsElem1] =
895 Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0);
897 Value carry0 = arith::CmpIOp::create(
898 rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0);
901 Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal);
902 Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1);
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
922 Value in = op.getIn();
924 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
927 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
936 Value isNeg = arith::CmpIOp::create(rewriter, loc,
937 arith::CmpIPredicate::slt, in, zeroCst);
938 Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in);
939 Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in);
941 Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs);
942 Value negResult = arith::NegFOp::create(rewriter, loc, absResult);
// Lowers arith.uitofp on an emulated wide integer: the low and high halves are
// each converted with a narrow uitofp, then recombined as
//   result = lowFp + hiFp * 2^newBitWidth.
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
965 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
// Width of each emulated half, taken from the converted element type.
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
// NOTE(review): how hiEqZero is consumed is not visible in this excerpt.
988 Value hiEqZero = arith::CmpIOp::create(
989 rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
991 Type resultTy = op.getType();
993 Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt);
994 Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt);
// NOTE(review): `int64_t(1) << newBitWidth` is undefined behaviour when
// newBitWidth >= 64 (e.g. emulating i128 when the widest supported integer is
// 64 bits). Building the constant without integer shifting, for example with
// std::ldexp(1.0, newBitWidth), would avoid this. Left unchanged here because
// the surrounding lines are not all visible.
996 int64_t pow2Int = int64_t(1) << newBitWidth;
998 rewriter.
getFloatAttr(resultElemTy,
static_cast<double>(pow2Int));
// Presumably the scalar attribute is wrapped for vector results here; the
// body of this branch is not visible in this excerpt.
999 if (
auto vecTy = dyn_cast<VectorType>(resultTy))
1003 arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr);
// Recombine: hiFp * 2^newBitWidth + lowFp.
1005 Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val);
1006 Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal);
1021 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1025 Value inFp = adaptor.getIn();
1028 Type intTy = op.getType();
1030 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1033 loc, llvm::formatv(
"unsupported type: {}", intTy));
1041 Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr);
1046 Value isNeg = arith::CmpFOp::create(
1047 rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst);
1048 Value negInFp = arith::NegFOp::create(rewriter, loc, inFp);
1050 Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp);
1053 Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal);
1056 Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res);
1071 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1075 Value inFp = adaptor.getIn();
1078 Type intTy = op.getType();
1079 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1082 loc, llvm::formatv(
"unsupported type: {}", intTy));
1083 unsigned newBitWidth = newTy.getElementTypeBitWidth();
1086 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1095 const llvm::fltSemantics &fSemantics =
1098 auto powBitwidth = llvm::APFloat(fSemantics);
1103 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
1104 false, llvm::RoundingMode::TowardZero) ==
1105 llvm::detail::opStatus::opInexact)
1106 powBitwidth = llvm::APFloat::getInf(fSemantics);
1108 TypedAttr powBitwidthAttr =
1110 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1112 Value powBitwidthFloatCst =
1113 arith::ConstantOp::create(rewriter, loc, powBitwidthAttr);
1115 Value fpDivPowBitwidth =
1116 arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1118 arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth);
1121 arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1123 arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder);
1143 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1148 if (!getTypeConverter()->isLegal(op.getType()))
1150 loc, llvm::formatv(
"unsupported truncation result type: {0}",
1158 rewriter.
createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1172 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
// Pass driver: emulates integer types wider than `widestIntSupported` using
// pairs of narrower integers.
1183 struct EmulateWideIntPass final
1184 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1185 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1187 void runOnOperation()
override {
// Reject widths the type converter cannot handle (mirrors the asserts in
// the WideIntEmulationConverter constructor): power of two and >= 2.
1188 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1189 signalPassFailure();
1196 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
// func.func is legal once its function type needs no conversion.
1198 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](
Operation *op) {
1199 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
// Calls, returns, vector.print and arith ops are legal when all of their
// types are legal under the converter; the rest of the vector dialect is
// always legal.
1201 auto opLegalCallback = [&typeConverter](
Operation *op) {
1202 return typeConverter.isLegal(op);
1204 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1205 target.addDynamicallyLegalOp<vector::PrintOp>(opLegalCallback);
1206 target.addDynamicallyLegalDialect<arith::ArithDialect>(opLegalCallback);
1207 target.addLegalDialect<vector::VectorDialect>();
1213 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
// NOTE(review): presumably reached when the conversion fails; the guarding
// condition is not visible in this excerpt.
1219 signalPassFailure();
// Type converter constructor fragment: records the widest integer width the
// target supports and registers conversions for integer, vector and function
// types.
1229 unsigned widestIntSupportedByTarget)
1230 : maxIntWidth(widestIntSupportedByTarget) {
// NOTE(review): the message below reads "integers with are supported"; it
// presumably should read "Only power-of-two integer widths are supported".
1231 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1232 "Only power-of-two integers with are supported");
1233 assert(widestIntSupportedByTarget >= 2 &&
"Integer type too narrow");
// Scalar integers: widths up to maxIntWidth and exactly 2 * maxIntWidth are
// handled; the bodies of these branches are not visible in this excerpt.
1239 addConversion([
this](IntegerType ty) -> std::optional<Type> {
1240 unsigned width = ty.getWidth();
1241 if (width <= maxIntWidth)
1245 if (width == 2 * maxIntWidth)
// Integer vectors: elements of width 2 * maxIntWidth get an extra trailing
// dimension of size 2 to hold the two halves.
1252 addConversion([
this](VectorType ty) -> std::optional<Type> {
1253 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1257 unsigned width = intTy.getWidth();
1258 if (width <= maxIntWidth)
1262 if (width == 2 * maxIntWidth) {
1263 auto newShape = to_vector(ty.getShape());
1264 newShape.push_back(2);
1273 addConversion([
this](FunctionType ty) -> std::optional<Type> {
1294 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1296 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1297 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1298 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1299 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1300 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1302 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1303 ConvertBitwiseBinary<arith::XOrIOp>,
1305 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1307 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1308 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1310 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1311 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1312 typeConverter,
patterns.getContext());
Attributes are known-constant values of operations.
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.