20#include "llvm/ADT/APFloat.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/Support/FormatVariadic.h"
23#include "llvm/Support/MathExtras.h"
27#define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
// Splits `value` into two pieces of `newBitWidth` bits each and returns them
// as a {low, high} pair: `low` holds bits [0, newBitWidth) and `high` holds
// bits [newBitWidth, 2 * newBitWidth) of `value`.
// NOTE(review): callers in this listing pass a value exactly twice as wide as
// `newBitWidth`; behavior for narrower inputs depends on APInt::extractBits.
41static std::pair<APInt, APInt>
getHalves(
const APInt &value,
42 unsigned newBitWidth) {
43 APInt low = value.extractBits(newBitWidth, 0);
44 APInt high = value.extractBits(newBitWidth, newBitWidth);
45 return {std::move(low), std::move(high)};
54 if (type.getShape().size() == 1)
55 return type.getElementType();
57 auto newShape = to_vector(type.getShape());
59 return VectorType::get(newShape, type.getElementType());
71 assert(lastOffset <
shape.back() &&
"Offset out of bounds");
74 if (
shape.size() == 1)
75 return vector::ExtractOp::create(rewriter, loc, input, lastOffset);
78 offsets.back() = lastOffset;
79 auto sizes = llvm::to_vector(
shape);
83 return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets,
89static std::pair<Value, Value>
100 auto vecTy = dyn_cast<VectorType>(input.
getType());
// Precondition: the input must have at least two dimensions so that one can
// be dropped and a vector type still remains.
106 assert(
shape.size() >= 2 &&
"Expected vector with at least two dims");
107 assert(
shape.back() == 1 &&
"Expected the last vector dim to be x1");
109 auto newVecTy = VectorType::get(
shape.drop_back(), vecTy.getElementType());
110 return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input);
117 auto vecTy = dyn_cast<VectorType>(input.
getType());
122 auto newShape = llvm::to_vector(vecTy.getShape());
123 newShape.push_back(1);
124 auto newTy = VectorType::get(newShape, vecTy.getElementType());
125 return vector::ShapeCastOp::create(rewriter, loc, newTy, input);
135 assert(lastOffset <
shape.back() &&
"Offset out of bounds");
138 if (isa<IntegerType>(source.
getType()))
139 return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset);
142 offsets.back() = lastOffset;
144 return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest,
155 Location loc, VectorType resultType,
159 assert(!resultShape.empty() &&
"Result expected to have dimensions");
160 assert(resultShape.back() ==
static_cast<int64_t>(resultComponents.size()) &&
161 "Wrong number of result components");
164 for (
auto [i, component] : llvm::enumerate(resultComponents))
175struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
180 ConversionPatternRewriter &rewriter)
const override {
181 Type oldType = op.getType();
182 auto newType = getTypeConverter()->convertType<VectorType>(oldType);
184 return rewriter.notifyMatchFailure(
185 op, llvm::formatv(
"unsupported type: {0}", op.getType()));
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
188 Attribute oldValue = op.getValueAttr();
190 if (
auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191 auto [low, high] =
getHalves(intAttr.getValue(), newBitWidth);
193 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
197 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
199 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
201 SmallVector<APInt> values;
202 values.reserve(numSplatElems * 2);
203 for (int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
209 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
213 if (
auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
215 SmallVector<APInt> values;
216 values.reserve(numElems * 2);
217 for (
const APInt &origVal : elemsAttr.getValues<APInt>()) {
218 auto [low, high] =
getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
224 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
228 return rewriter.notifyMatchFailure(op.getLoc(),
229 "unhandled constant attribute");
// Conversion pattern for arith.addi on a type that the type converter maps to
// a vector of narrower elements. The visible lines build the wide sum from
// element pairs (elem0 / elem1) of each operand:
//   - the elem0 values are added with arith.addui_extended,
//   - the overflow result of that addition is zero-extended to `newElemTy`,
//   - the extended overflow, lhsElem1 and rhsElem1 are summed into `high`.
// NOTE(review): this listing omits several source lines (the embedded line
// numbers skip), including the null check guarding the failure return and
// the code that extracts the operand elements and assembles `resultVec`.
237struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter)
const override {
243 Location loc = op->getLoc();
244 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
// Reported when the result type cannot be converted to a VectorType.
246 return rewriter.notifyMatchFailure(
247 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
251 auto [lhsElem0, lhsElem1] =
253 auto [rhsElem0, rhsElem1] =
// Extended unsigned add of the elem0 values; yields a sum and an overflow bit.
257 arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
// Widen the overflow bit to the element type so it can be added as a carry.
259 arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow());
// high = carry + lhsElem1 + rhsElem1.
261 Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1);
262 Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1);
266 rewriter.replaceOp(op, resultVec);
276template <
typename BinaryOp>
277struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
278 using OpConversionPattern<BinaryOp>::OpConversionPattern;
279 using OpAdaptor =
typename OpConversionPattern<BinaryOp>::OpAdaptor;
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter)
const override {
284 Location loc = op->getLoc();
285 auto newTy = this->getTypeConverter()->template convertType<VectorType>(
288 return rewriter.notifyMatchFailure(
289 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
291 auto [lhsElem0, lhsElem1] =
293 auto [rhsElem0, rhsElem1] =
296 Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0);
297 Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1);
300 rewriter.replaceOp(op, resultVec);
311static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
327struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
332 ConversionPatternRewriter &rewriter)
const override {
333 Location loc = op->getLoc();
335 getTypeConverter()->convertType<VectorType>(op.getLhs().
getType());
// The type that failed conversion is the operand type (the type of the lhs),
// not the comparison's result type, so report the operand type here.
337 return rewriter.notifyMatchFailure(
338 loc, llvm::formatv(
"unsupported type: {0}", op.getLhs().getType()));
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
343 auto [lhsElem0, lhsElem1] =
345 auto [rhsElem0, rhsElem1] =
349 arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0);
351 arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1);
355 case arith::CmpIPredicate::eq: {
356 cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp);
359 case arith::CmpIPredicate::ne: {
360 cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp);
365 Value highEq = arith::CmpIOp::create(
366 rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
368 arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp);
373 assert(cmpResult &&
"Unhandled case");
// Conversion pattern for arith.muli on a type that the type converter maps to
// a vector of narrower elements. From the visible lines:
//   - lhsElem0 * rhsElem0 is computed with arith.mului_extended, producing a
//     low and a high part (`mulLowLow`),
//   - the cross products lhsElem0 * rhsElem1 and lhsElem1 * rhsElem0 are
//     computed with plain arith.muli,
//   - resLow is the low part of `mulLowLow`; resHi is the high part of
//     `mulLowLow` plus both cross products.
// No lhsElem1 * rhsElem1 product appears in the visible lines.
// NOTE(review): this listing omits several source lines (the embedded line
// numbers skip), including the null check guarding the failure return and
// the code that extracts the operand elements and assembles `resultVec`.
383struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter)
const override {
389 Location loc = op->getLoc();
390 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
// Reported when the result type cannot be converted to a VectorType.
392 return rewriter.notifyMatchFailure(
393 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
395 auto [lhsElem0, lhsElem1] =
397 auto [rhsElem0, rhsElem1] =
// Full-width product of the elem0 values (low and high parts).
404 arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
// Cross products; only their truncated results are used.
405 Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1);
406 Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0);
408 Value resLow = mulLowLow.getLow();
// resHi = high(mulLowLow) + mulLowHi + mulHiLow.
410 arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow);
415 rewriter.replaceOp(op, resultVec);
424struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter)
const override {
430 Location loc = op->getLoc();
431 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
433 return rewriter.notifyMatchFailure(
434 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
441 Value newOperand =
appendX1Dim(rewriter, loc, adaptor.getIn());
442 Value extended = rewriter.createOrFold<arith::ExtSIOp>(
443 loc, newResultComponentTy, newOperand);
444 Value operandZeroCst =
446 Value signBit = arith::CmpIOp::create(
447 rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
449 arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit);
453 rewriter.replaceOp(op, resultVec);
462struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter)
const override {
468 Location loc = op->getLoc();
469 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
471 return rewriter.notifyMatchFailure(
472 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
478 Value newOperand =
appendX1Dim(rewriter, loc, adaptor.getIn());
479 Value extended = rewriter.createOrFold<arith::ExtUIOp>(
480 loc, newResultComponentTy, newOperand);
483 rewriter.replaceOp(op, newRes);
492template <
typename SourceOp, arith::CmpIPredicate CmpPred>
493struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
494 using OpConversionPattern<SourceOp>::OpConversionPattern;
497 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
498 ConversionPatternRewriter &rewriter)
const override {
499 Location loc = op->getLoc();
501 Type oldTy = op.getType();
502 auto newTy = dyn_cast_or_null<VectorType>(
503 this->getTypeConverter()->convertType(oldTy));
505 return rewriter.notifyMatchFailure(
506 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
511 arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs());
512 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
522static bool isIndexOrIndexVector(
Type type) {
523 if (isa<IndexType>(type))
526 if (
auto vectorTy = dyn_cast<VectorType>(type))
527 if (isa<IndexType>(vectorTy.getElementType()))
533template <
typename CastOp>
534struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
535 using OpConversionPattern<CastOp>::OpConversionPattern;
538 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const override {
540 Type resultType = op.getType();
541 if (!isIndexOrIndexVector(resultType))
544 Location loc = op.getLoc();
545 Type inType = op.getIn().getType();
547 this->getTypeConverter()->template convertType<VectorType>(inType);
549 return rewriter.notifyMatchFailure(
550 loc, llvm::formatv(
"unsupported type: {0}", inType));
555 rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
560template <
typename CastOp,
typename ExtensionOp>
561struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
562 using OpConversionPattern<CastOp>::OpConversionPattern;
565 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
566 ConversionPatternRewriter &rewriter)
const override {
567 Type inType = op.getIn().getType();
568 if (!isIndexOrIndexVector(inType))
571 Location loc = op.getLoc();
572 auto *typeConverter =
573 this->
template getTypeConverter<arith::WideIntEmulationConverter>();
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType<VectorType>(resultType);
578 return rewriter.notifyMatchFailure(
579 loc, llvm::formatv(
"unsupported type: {0}", resultType));
583 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (
auto vecTy = dyn_cast<VectorType>(resultType))
585 narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
589 Value underlyingVal =
590 CastOp::create(rewriter, loc, narrowTy, adaptor.getIn());
591 rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
600struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
605 ConversionPatternRewriter &rewriter)
const override {
606 Location loc = op->getLoc();
607 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
609 return rewriter.notifyMatchFailure(
610 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
612 auto [trueElem0, trueElem1] =
614 auto [falseElem0, falseElem1] =
616 Value cond =
appendX1Dim(rewriter, loc, adaptor.getCondition());
619 arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0);
621 arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1);
624 rewriter.replaceOp(op, resultVec);
633struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
638 ConversionPatternRewriter &rewriter)
const override {
639 Location loc = op->getLoc();
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
644 return rewriter.notifyMatchFailure(
645 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
651 auto [lhsElem0, lhsElem1] =
683 Value illegalElemShift = arith::CmpIOp::create(
684 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
687 arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0);
688 Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
689 zeroCst, shiftedElem0);
691 Value cappedShiftAmount = arith::SelectOp::create(
692 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
696 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount);
697 Value overshotShiftAmount =
698 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
700 arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount);
703 arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0);
704 Value resElem1High = arith::SelectOp::create(
705 rewriter, loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
707 shiftedLeft, shiftedRight);
709 arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High);
713 rewriter.replaceOp(op, resultVec);
722struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter)
const override {
728 Location loc = op->getLoc();
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
733 return rewriter.notifyMatchFailure(
734 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
740 auto [lhsElem0, lhsElem1] =
772 Value illegalElemShift = arith::CmpIOp::create(
773 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
776 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0);
777 Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
778 zeroCst, shiftedElem0);
780 arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0);
781 Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
782 zeroCst, shiftedElem1);
784 Value cappedShiftAmount = arith::SelectOp::create(
785 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
789 arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount);
790 Value overshotShiftAmount =
791 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
793 arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount);
795 Value resElem0High = arith::SelectOp::create(
796 rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft);
798 arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High);
802 rewriter.replaceOp(op, resultVec);
811struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
816 ConversionPatternRewriter &rewriter)
const override {
817 Location loc = op->getLoc();
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
822 return rewriter.notifyMatchFailure(
823 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
828 Type narrowTy = rhsElem0.
getType();
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
835 Value signBit = arith::CmpIOp::create(
836 rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
842 Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit);
845 Value numNonSignExtBits =
846 arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0);
849 arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits);
851 arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits);
855 arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs());
856 Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits);
860 Value isNoop = arith::CmpIOp::create(
861 rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero);
863 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
874struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
879 ConversionPatternRewriter &rewriter)
const override {
880 Location loc = op->getLoc();
881 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
// Reported when the result type cannot be converted to a VectorType. Uses the
// explicit `{0}` placeholder, matching the other patterns in this file.
883 return rewriter.notifyMatchFailure(
884 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
888 auto [lhsElem0, lhsElem1] =
890 auto [rhsElem0, rhsElem1] =
895 Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0);
897 Value carry0 = arith::CmpIOp::create(
898 rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0);
901 Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal);
902 Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1);
905 rewriter.replaceOp(op, resultVec);
914struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter)
const override {
920 Location loc = op.getLoc();
922 Value in = op.getIn();
924 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
926 return rewriter.notifyMatchFailure(
927 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
936 Value isNeg = arith::CmpIOp::create(rewriter, loc,
937 arith::CmpIPredicate::slt, in, zeroCst);
938 Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in);
939 Value
abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in);
941 Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs);
942 Value negResult = arith::NegFOp::create(rewriter, loc, absResult);
943 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
953struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
958 ConversionPatternRewriter &rewriter)
const override {
959 Location loc = op.getLoc();
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
964 return rewriter.notifyMatchFailure(
965 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
988 Value hiEqZero = arith::CmpIOp::create(
989 rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
991 Type resultTy = op.getType();
993 Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt);
994 Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt);
// Compute 2^newBitWidth as a double through APInt rather than with
// `int64_t(1) << newBitWidth`: that shift is undefined when newBitWidth is 64
// and yields a negative value when it is 63. The APInt is one bit wider than
// the shift amount so the set bit always fits.
996 const double pow2Fp =
APInt::getOneBitSet(newBitWidth + 1, newBitWidth).roundToDouble();
998 rewriter.getFloatAttr(resultElemTy, pow2Fp);
999 if (
auto vecTy = dyn_cast<VectorType>(resultTy))
1000 pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
1003 arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr);
1005 Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val);
1006 Value
result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal);
1008 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp,
result);
1017struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
1021 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter)
const override {
1023 Location loc = op.getLoc();
1025 Value inFp = adaptor.getIn();
1028 Type intTy = op.getType();
1030 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
// Reported when the integer result type cannot be converted to a VectorType.
// Uses the explicit `{0}` placeholder, matching the other patterns in this
// file.
1032 return rewriter.notifyMatchFailure(
1033 loc, llvm::formatv(
"unsupported type: {0}", intTy));
1040 TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
1041 Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr);
1046 Value isNeg = arith::CmpFOp::create(
1047 rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst);
1048 Value negInFp = arith::NegFOp::create(rewriter, loc, inFp);
1050 Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp);
1053 Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal);
1056 Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res);
1058 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
1067struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
1071 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter)
const override {
1073 Location loc = op.getLoc();
1075 Value inFp = adaptor.getIn();
1078 Type intTy = op.getType();
1079 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
// Reported when the integer result type cannot be converted to a VectorType.
// Uses the explicit `{0}` placeholder, matching the other patterns in this
// file.
1081 return rewriter.notifyMatchFailure(
1082 loc, llvm::formatv(
"unsupported type: {0}", intTy));
1083 unsigned newBitWidth = newTy.getElementTypeBitWidth();
1085 Type newHalfType = IntegerType::get(inFp.
getContext(), newBitWidth);
1086 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1087 newHalfType = VectorType::get(vecType.getShape(), newHalfType);
1095 const llvm::fltSemantics &fSemantics =
1098 auto powBitwidth = llvm::APFloat(fSemantics);
1103 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
1104 false, llvm::RoundingMode::TowardZero) ==
1105 llvm::detail::opStatus::opInexact)
1106 powBitwidth = llvm::APFloat::getInf(fSemantics);
1108 TypedAttr powBitwidthAttr =
1110 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1111 powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
1112 Value powBitwidthFloatCst =
1113 arith::ConstantOp::create(rewriter, loc, powBitwidthAttr);
1115 Value fpDivPowBitwidth =
1116 arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1118 arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth);
1121 arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1123 arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder);
1130 rewriter.replaceOp(op, resultVec);
1139struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
1143 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1144 ConversionPatternRewriter &rewriter)
const override {
1145 Location loc = op.getLoc();
1148 if (!getTypeConverter()->isLegal(op.getType()))
1149 return rewriter.notifyMatchFailure(
1150 loc, llvm::formatv(
"unsupported truncation result type: {0}",
1158 rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1159 rewriter.replaceOp(op, truncated);
1168struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1172 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1173 ConversionPatternRewriter &rewriter)
const override {
1174 rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1183struct EmulateWideIntPass final
1184 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1185 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1187 void runOnOperation()
override {
1188 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1189 signalPassFailure();
1193 Operation *op = getOperation();
1196 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1197 ConversionTarget
target(*ctx);
1198 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1199 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1201 auto opLegalCallback = [&typeConverter](Operation *op) {
1202 return typeConverter.isLegal(op);
1204 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1205 target.addDynamicallyLegalOp<vector::PrintOp>(opLegalCallback);
1206 target.addDynamicallyLegalDialect<arith::ArithDialect>(opLegalCallback);
1207 target.addLegalDialect<vector::VectorDialect>();
1210 arith::populateArithWideIntEmulationPatterns(typeConverter,
patterns);
1213 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1219 signalPassFailure();
1229 unsigned widestIntSupportedByTarget)
1230 : maxIntWidth(widestIntSupportedByTarget) {
// The widest supported integer width must be a power of two.
1231 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1232 "Only power-of-two integer widths are supported");
1233 assert(widestIntSupportedByTarget >= 2 &&
"Integer type too narrow");
1236 addConversion([](
Type ty) -> std::optional<Type> {
return ty; });
1239 addConversion([
this](IntegerType ty) -> std::optional<Type> {
1240 unsigned width = ty.getWidth();
1241 if (width <= maxIntWidth)
1245 if (width == 2 * maxIntWidth)
1246 return VectorType::get(2, IntegerType::get(ty.
getContext(), maxIntWidth));
1252 addConversion([
this](VectorType ty) -> std::optional<Type> {
1253 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1257 unsigned width = intTy.getWidth();
1258 if (width <= maxIntWidth)
1262 if (width == 2 * maxIntWidth) {
1263 auto newShape = to_vector(ty.getShape());
1264 newShape.push_back(2);
1265 return VectorType::get(newShape,
1266 IntegerType::get(ty.
getContext(), maxIntWidth));
1273 addConversion([
this](FunctionType ty) -> std::optional<Type> {
1277 if (failed(convertTypes(ty.getInputs(), inputs)))
1281 if (failed(convertTypes(ty.getResults(), results)))
1284 return FunctionType::get(ty.
getContext(), inputs, results);
1294 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1296 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1297 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1298 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1299 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1300 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1302 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1303 ConvertBitwiseBinary<arith::XOrIOp>,
1305 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1307 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1308 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1310 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1311 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1312 typeConverter,
patterns.getContext());
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext()
Return the context this operation is associated with.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
Fraction abs(const Fraction &f)
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns