20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/MathExtras.h"
27 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
41 static std::pair<APInt, APInt>
getHalves(
const APInt &value,
42 unsigned newBitWidth) {
43 APInt low = value.extractBits(newBitWidth, 0);
44 APInt high = value.extractBits(newBitWidth, newBitWidth);
45 return {std::move(low), std::move(high)};
54 if (type.getShape().size() == 1)
55 return type.getElementType();
57 auto newShape = to_vector(type.getShape());
71 assert(lastOffset < shape.back() &&
"Offset out of bounds");
74 if (shape.size() == 1)
75 return rewriter.
create<vector::ExtractOp>(loc, input, lastOffset);
78 offsets.back() = lastOffset;
79 auto sizes = llvm::to_vector(shape);
83 return rewriter.
create<vector::ExtractStridedSliceOp>(loc, input, offsets,
89 static std::pair<Value, Value>
100 auto vecTy = dyn_cast<VectorType>(input.
getType());
106 assert(shape.size() >= 2 &&
"Expected vector with at list two dims");
107 assert(shape.back() == 1 &&
"Expected the last vector dim to be x1");
109 auto newVecTy =
VectorType::get(shape.drop_back(), vecTy.getElementType());
110 return rewriter.
create<vector::ShapeCastOp>(loc, newVecTy, input);
117 auto vecTy = dyn_cast<VectorType>(input.
getType());
122 auto newShape = llvm::to_vector(vecTy.getShape());
123 newShape.push_back(1);
125 return rewriter.
create<vector::ShapeCastOp>(loc, newTy, input);
133 int64_t lastOffset) {
135 assert(lastOffset < shape.back() &&
"Offset out of bounds");
138 if (isa<IntegerType>(source.
getType()))
139 return rewriter.
create<vector::InsertOp>(loc, source, dest, lastOffset);
142 offsets.back() = lastOffset;
144 return rewriter.
create<vector::InsertStridedSliceOp>(loc, source, dest,
155 Location loc, VectorType resultType,
159 assert(!resultShape.empty() &&
"Result expected to have dimensions");
160 assert(resultShape.back() ==
static_cast<int64_t
>(resultComponents.size()) &&
161 "Wrong number of result components");
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
181 Type oldType = op.getType();
182 auto newType = getTypeConverter()->convertType<VectorType>(oldType);
185 op, llvm::formatv(
"unsupported type: {0}", op.getType()));
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
190 if (
auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191 auto [low, high] =
getHalves(intAttr.getValue(), newBitWidth);
197 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
199 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
202 values.reserve(numSplatElems * 2);
203 for (int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
213 if (
auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
216 values.reserve(numElems * 2);
217 for (
const APInt &origVal : elemsAttr.getValues<APInt>()) {
218 auto [low, high] =
getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
229 "unhandled constant attribute");
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
244 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
247 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
251 auto [lhsElem0, lhsElem1] =
253 auto [rhsElem0, rhsElem1] =
257 rewriter.
create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
259 rewriter.
create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
261 Value high0 = rewriter.
create<arith::AddIOp>(loc, overflowVal, lhsElem1);
262 Value high = rewriter.
create<arith::AddIOp>(loc, high0, rhsElem1);
276 template <
typename BinaryOp>
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
285 auto newTy = this->getTypeConverter()->template convertType<VectorType>(
289 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
291 auto [lhsElem0, lhsElem1] =
293 auto [rhsElem0, rhsElem1] =
296 Value resElem0 = rewriter.
create<BinaryOp>(loc, lhsElem0, rhsElem0);
297 Value resElem1 = rewriter.
create<BinaryOp>(loc, lhsElem1, rhsElem1);
311 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
335 getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
338 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
343 auto [lhsElem0, lhsElem1] =
345 auto [rhsElem0, rhsElem1] =
349 rewriter.
create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
351 rewriter.
create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
355 case arith::CmpIPredicate::eq: {
356 cmpResult = rewriter.
create<arith::AndIOp>(loc, lowCmp, highCmp);
359 case arith::CmpIPredicate::ne: {
360 cmpResult = rewriter.
create<arith::OrIOp>(loc, lowCmp, highCmp);
366 loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
368 rewriter.
create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
373 assert(cmpResult &&
"Unhandled case");
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
390 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
393 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
395 auto [lhsElem0, lhsElem1] =
397 auto [rhsElem0, rhsElem1] =
404 rewriter.
create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
405 Value mulLowHi = rewriter.
create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
406 Value mulHiLow = rewriter.
create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
408 Value resLow = mulLowLow.getLow();
410 rewriter.
create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = rewriter.
create<arith::AddIOp>(loc, resHi, mulHiLow);
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
431 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
434 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
443 loc, newResultComponentTy, newOperand);
444 Value operandZeroCst =
447 loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
449 rewriter.
create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
469 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
472 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
480 loc, newResultComponentTy, newOperand);
492 template <
typename SourceOp, arith::CmpIPredicate CmpPred>
497 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
501 Type oldTy = op.getType();
502 auto newTy = dyn_cast_or_null<VectorType>(
503 this->getTypeConverter()->convertType(oldTy));
506 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
511 rewriter.
create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
522 static bool isIndexOrIndexVector(
Type type) {
523 if (isa<IndexType>(type))
526 if (
auto vectorTy = dyn_cast<VectorType>(type))
527 if (isa<IndexType>(vectorTy.getElementType()))
533 template <
typename CastOp>
538 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
540 Type resultType = op.getType();
541 if (!isIndexOrIndexVector(resultType))
545 Type inType = op.getIn().getType();
547 this->getTypeConverter()->template convertType<VectorType>(inType);
550 loc, llvm::formatv(
"unsupported type: {0}", inType));
560 template <
typename CastOp,
typename ExtensionOp>
565 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
567 Type inType = op.getIn().getType();
568 if (!isIndexOrIndexVector(inType))
572 auto *typeConverter =
573 this->
template getTypeConverter<arith::WideIntEmulationConverter>();
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType<VectorType>(resultType);
579 loc, llvm::formatv(
"unsupported type: {0}", resultType));
583 rewriter.
getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (
auto vecTy = dyn_cast<VectorType>(resultType))
589 Value underlyingVal =
590 rewriter.
create<CastOp>(loc, narrowTy, adaptor.getIn());
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
607 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
610 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
612 auto [trueElem0, trueElem1] =
614 auto [falseElem0, falseElem1] =
619 rewriter.
create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
621 rewriter.
create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
645 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
651 auto [lhsElem0, lhsElem1] =
683 Value illegalElemShift = rewriter.
create<arith::CmpIOp>(
684 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
687 rewriter.
create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
688 Value resElem0 = rewriter.
create<arith::SelectOp>(loc, illegalElemShift,
689 zeroCst, shiftedElem0);
691 Value cappedShiftAmount = rewriter.
create<arith::SelectOp>(
692 loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 rewriter.
create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
696 rewriter.
create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
697 Value overshotShiftAmount =
698 rewriter.
create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
700 rewriter.
create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
703 rewriter.
create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
704 Value resElem1High = rewriter.
create<arith::SelectOp>(
705 loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = rewriter.
create<arith::SelectOp>(
707 loc, illegalElemShift, shiftedLeft, shiftedRight);
709 rewriter.
create<arith::OrIOp>(loc, resElem1Low, resElem1High);
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
734 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
740 auto [lhsElem0, lhsElem1] =
772 Value illegalElemShift = rewriter.
create<arith::CmpIOp>(
773 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
776 rewriter.
create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
777 Value resElem0Low = rewriter.
create<arith::SelectOp>(loc, illegalElemShift,
778 zeroCst, shiftedElem0);
780 rewriter.
create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
781 Value resElem1 = rewriter.
create<arith::SelectOp>(loc, illegalElemShift,
782 zeroCst, shiftedElem1);
784 Value cappedShiftAmount = rewriter.
create<arith::SelectOp>(
785 loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 rewriter.
create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
789 rewriter.
create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
790 Value overshotShiftAmount =
791 rewriter.
create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
793 rewriter.
create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
795 Value resElem0High = rewriter.
create<arith::SelectOp>(
796 loc, illegalElemShift, shiftedRight, shiftedLeft);
798 rewriter.
create<arith::OrIOp>(loc, resElem0Low, resElem0High);
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
823 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
836 loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
842 Value allSign = rewriter.
create<arith::ExtSIOp>(loc, oldTy, signBit);
845 Value numNonSignExtBits =
846 rewriter.
create<arith::SubIOp>(loc, maxShift, rhsElem0);
849 rewriter.
create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
851 rewriter.
create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
855 rewriter.
create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
856 Value shrsi = rewriter.
create<arith::OrIOp>(loc, shrui, signBits);
860 Value isNoop = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
881 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
884 loc, llvm::formatv(
"unsupported type: {}", op.getType()));
888 auto [lhsElem0, lhsElem1] =
890 auto [rhsElem0, rhsElem1] =
895 Value low = rewriter.
create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
898 loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = rewriter.
create<arith::ExtUIOp>(loc, newElemTy, carry0);
901 Value high0 = rewriter.
create<arith::SubIOp>(loc, lhsElem1, carryVal);
902 Value high = rewriter.
create<arith::SubIOp>(loc, high0, rhsElem1);
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
922 Value in = op.getIn();
924 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
927 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
936 Value isNeg = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
938 Value neg = rewriter.
create<arith::SubIOp>(loc, zeroCst, in);
939 Value abs = rewriter.
create<arith::SelectOp>(loc, isNeg, neg, in);
941 Value absResult = rewriter.
create<arith::UIToFPOp>(loc, op.getType(),
abs);
942 Value negResult = rewriter.
create<arith::NegFOp>(loc, absResult);
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
965 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
989 loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
991 Type resultTy = op.getType();
993 Value lowFp = rewriter.
create<arith::UIToFPOp>(loc, resultTy, lowInt);
994 Value hiFp = rewriter.
create<arith::UIToFPOp>(loc, resultTy, hiInt);
996 int64_t pow2Int = int64_t(1) << newBitWidth;
998 rewriter.
getFloatAttr(resultElemTy,
static_cast<double>(pow2Int));
999 if (
auto vecTy = dyn_cast<VectorType>(resultTy))
1002 Value pow2Val = rewriter.
create<arith::ConstantOp>(loc, resultTy, pow2Attr);
1004 Value hiVal = rewriter.
create<arith::MulFOp>(loc, hiFp, pow2Val);
1005 Value result = rewriter.
create<arith::AddFOp>(loc, lowFp, hiVal);
1020 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1024 Value inFp = adaptor.getIn();
1027 Type intTy = op.getType();
1029 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1032 loc, llvm::formatv(
"unsupported type: {}", intTy));
1040 Value zeroCst = rewriter.
create<arith::ConstantOp>(loc, zeroAttr);
1045 Value isNeg = rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
1047 Value negInFp = rewriter.
create<arith::NegFOp>(loc, inFp);
1049 Value absVal = rewriter.
create<arith::SelectOp>(loc, isNeg, negInFp, inFp);
1052 Value res = rewriter.
create<arith::FPToUIOp>(loc, intTy, absVal);
1055 Value neg = rewriter.
create<arith::SubIOp>(loc, zeroCstInt, res);
1070 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1074 Value inFp = adaptor.getIn();
1077 Type intTy = op.getType();
1078 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1081 loc, llvm::formatv(
"unsupported type: {}", intTy));
1082 unsigned newBitWidth = newTy.getElementTypeBitWidth();
1085 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1094 const llvm::fltSemantics &fSemantics =
1097 auto powBitwidth = llvm::APFloat(fSemantics);
1102 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
1103 false, llvm::RoundingMode::TowardZero) ==
1104 llvm::detail::opStatus::opInexact)
1105 powBitwidth = llvm::APFloat::getInf(fSemantics);
1107 TypedAttr powBitwidthAttr =
1109 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1111 Value powBitwidthFloatCst =
1112 rewriter.
create<arith::ConstantOp>(loc, powBitwidthAttr);
1114 Value fpDivPowBitwidth =
1115 rewriter.
create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
1117 rewriter.
create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
1120 rewriter.
create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
1122 rewriter.
create<arith::FPToUIOp>(loc, newHalfType, remainder);
1142 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1147 if (!getTypeConverter()->isLegal(op.getType()))
1149 loc, llvm::formatv(
"unsupported truncation result type: {0}",
1157 rewriter.
createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1171 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1182 struct EmulateWideIntPass final
1183 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1184 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1186 void runOnOperation()
override {
1187 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1188 signalPassFailure();
1195 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1197 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](
Operation *op) {
1198 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1200 auto opLegalCallback = [&typeConverter](
Operation *op) {
1201 return typeConverter.isLegal(op);
1203 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1204 target.addDynamicallyLegalOp<vector::PrintOp>(opLegalCallback);
1205 target.addDynamicallyLegalDialect<arith::ArithDialect>(opLegalCallback);
1206 target.addLegalDialect<vector::VectorDialect>();
1212 signalPassFailure();
1222 unsigned widestIntSupportedByTarget)
1223 : maxIntWidth(widestIntSupportedByTarget) {
1224 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1225 "Only power-of-two integers with are supported");
1226 assert(widestIntSupportedByTarget >= 2 &&
"Integer type too narrow");
1232 addConversion([
this](IntegerType ty) -> std::optional<Type> {
1233 unsigned width = ty.getWidth();
1234 if (width <= maxIntWidth)
1238 if (width == 2 * maxIntWidth)
1245 addConversion([
this](VectorType ty) -> std::optional<Type> {
1246 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1250 unsigned width = intTy.getWidth();
1251 if (width <= maxIntWidth)
1255 if (width == 2 * maxIntWidth) {
1256 auto newShape = to_vector(ty.getShape());
1257 newShape.push_back(2);
1266 addConversion([
this](FunctionType ty) -> std::optional<Type> {
1285 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
1293 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1295 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1296 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1297 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1298 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1299 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1301 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1302 ConvertBitwiseBinary<arith::XOrIOp>,
1304 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1306 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1307 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1308 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1310 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1311 typeConverter,
patterns.getContext());
Attributes are known-constant values of operations.
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.