155 Location loc, VectorType resultType,
159 assert(!resultShape.empty() &&
"Result expected to have dimensions");
160 assert(resultShape.back() ==
static_cast<int64_t>(resultComponents.size()) &&
161 "Wrong number of result components");
164 for (
auto [i, component] : llvm::enumerate(resultComponents))
// Rewrites an `arith.constant` whose type the type converter maps to a
// `VectorType`. Each wide integer value is split into two halves with
// `getHalves(value, newBitWidth)` and a new constant is emitted.
// Handled attribute kinds, in order: `IntegerAttr`, `SplatElementsAttr`,
// and `DenseElementsAttr`. Any other attribute kind is a match failure.
175struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
180 ConversionPatternRewriter &rewriter)
const override {
182 auto newType = getTypeConverter()->convertType<VectorType>(oldType);
184 return rewriter.notifyMatchFailure(
185 op, llvm::formatv(
"unsupported type: {0}", op.getType()));
// Bit width of each half in the converted type.
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
// Scalar integer constant: split the single value into (low, high).
190 if (
auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191 auto [low, high] =
getHalves(intAttr.getValue(), newBitWidth);
193 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
// Splat constant: split the splat value once, then repeat the
// (low, high) pair once per element.
197 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
199 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
202 values.reserve(numSplatElems * 2);
203 for (
int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
209 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
// General dense constant: split every element. The halves are stored
// interleaved as low0, high0, low1, high1, ...
213 if (
auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
216 values.reserve(numElems * 2);
217 for (
const APInt &origVal : elemsAttr.getValues<APInt>()) {
218 auto [low, high] =
getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
224 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
// None of the handled attribute kinds matched.
228 return rewriter.notifyMatchFailure(op.getLoc(),
229 "unhandled constant attribute");
// Lowers `arith.addi` on an emulated wide integer to operations on its two
// halves (element 0 = low half, element 1 = high half):
//   low  = lhsLow + rhsLow              (arith.addui_extended)
//   high = carry + lhsHigh + rhsHigh    (carry = zero-extended overflow bit)
237struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter)
const override {
243 Location loc = op->getLoc();
244 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
246 return rewriter.notifyMatchFailure(
247 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
251 auto [lhsElem0, lhsElem1] =
253 auto [rhsElem0, rhsElem1] =
// Low halves are added with an explicit overflow flag, used as the carry.
257 arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
259 arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow());
// Propagate the carry into the high-half sum.
261 Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1);
262 Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1);
266 rewriter.replaceOp(op, resultVec);
// Generic lowering for bitwise binary ops (instantiated for AndI, OrI and
// XOrI). Bitwise ops have no cross-bit interaction, so the op is applied
// independently to the low halves and to the high halves.
276template <
typename BinaryOp>
277struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
278 using OpConversionPattern<BinaryOp>::OpConversionPattern;
279 using OpAdaptor =
typename OpConversionPattern<BinaryOp>::OpAdaptor;
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter)
const override {
284 Location loc = op->getLoc();
285 auto newTy = this->getTypeConverter()->template convertType<VectorType>(
288 return rewriter.notifyMatchFailure(
289 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
291 auto [lhsElem0, lhsElem1] =
293 auto [rhsElem0, rhsElem1] =
// Same op on each half; there is no carry between halves.
296 Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0);
297 Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1);
300 rewriter.replaceOp(op, resultVec);
// Returns the predicate to use when comparing the low halves of two wide
// integers (see ConvertCmpI). The mapping is not visible here; presumably
// signed orderings map to their unsigned counterparts and other predicates
// are returned unchanged. TODO confirm against the function body.
311static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
// Lowers `arith.cmpi` on emulated wide integers by comparing the halves.
// The high halves use the original predicate; the low halves use
// `toUnsignedPredicate(pred)`. Results are combined per predicate:
//   eq: lowCmp && highCmp
//   ne: lowCmp || highCmp
//   orderings: (lhsHigh == rhsHigh) ? lowCmp : highCmp
// The type conversion is keyed on the operand type, because the result is
// i1. The failure diagnostic therefore reports the operand type as well.
327struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
332 ConversionPatternRewriter &rewriter)
const override {
333 Location loc = op->getLoc();
335 getTypeConverter()->convertType<VectorType>(op.getLhs().
getType());
// Report the operand type that failed to convert, not the i1 result type.
337 return rewriter.notifyMatchFailure(
338 loc, llvm::formatv(
"unsupported type: {0}", op.getLhs().getType()));
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
343 auto [lhsElem0, lhsElem1] =
345 auto [rhsElem0, rhsElem1] =
349 arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0);
351 arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1);
// Equality requires both halves to be equal.
355 case arith::CmpIPredicate::eq: {
356 cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp);
// Inequality holds if either half differs.
359 case arith::CmpIPredicate::ne: {
360 cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp);
// Ordering: the high halves decide unless they are equal, in which case
// the (unsigned) low-half comparison decides.
365 Value highEq = arith::CmpIOp::create(
366 rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
368 arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp);
373 assert(cmpResult &&
"Unhandled case");
// Lowers `arith.muli` on emulated wide integers by schoolbook multiplication
// of the halves (L = low half, H = high half, BW = half width):
//   low  = low(lhsL * rhsL)
//   high = high(lhsL * rhsL) + lhsL * rhsH + lhsH * rhsL   (mod 2^BW)
// The lhsH * rhsH term is omitted. It only contributes at bit 2*BW and
// above, which lies outside the wide result.
383struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter)
const override {
389 Location loc = op->getLoc();
390 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
392 return rewriter.notifyMatchFailure(
393 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
395 auto [lhsElem0, lhsElem1] =
397 auto [rhsElem0, rhsElem1] =
// Full double-width product of the low halves.
404 arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
// Cross terms; only their low BW bits land in the result's high half.
405 Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1);
406 Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0);
408 Value resLow = mulLowLow.getLow();
410 arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow);
415 rewriter.replaceOp(op, resultVec);
// Lowers `arith.extsi` producing an emulated wide integer. The input is
// given a trailing x1 dimension (`appendX1Dim`) and sign-extended to the
// half type to form the low half. The high half is the sign-extension of
// `extended < operandZeroCst`: all ones for negative inputs, zero otherwise.
424struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter)
const override {
430 Location loc = op->getLoc();
431 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
433 return rewriter.notifyMatchFailure(
434 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
441 Value newOperand =
appendX1Dim(rewriter, loc, adaptor.getIn());
442 Value extended = rewriter.createOrFold<arith::ExtSIOp>(
443 loc, newResultComponentTy, newOperand);
// `operandZeroCst` is presumably a zero constant of the half type (its
// initializer is not shown here).
444 Value operandZeroCst =
446 Value signBit = arith::CmpIOp::create(
447 rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
// Sign-extending the i1 sign bit gives an all-ones or all-zero high half.
449 arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit);
453 rewriter.replaceOp(op, resultVec);
// Lowers `arith.extui` producing an emulated wide integer. The input is
// given a trailing x1 dimension and zero-extended to the half type. The
// code that builds `newRes` from `extended` is not shown here; presumably
// it pairs it with a zero high half. TODO confirm.
462struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter)
const override {
468 Location loc = op->getLoc();
469 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
471 return rewriter.notifyMatchFailure(
472 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
478 Value newOperand =
appendX1Dim(rewriter, loc, adaptor.getIn());
479 Value extended = rewriter.createOrFold<arith::ExtUIOp>(
480 loc, newResultComponentTy, newOperand);
483 rewriter.replaceOp(op, newRes);
// Rewrites max/min ops (maxui/maxsi/minui/minsi) as `cmpi CmpPred` followed
// by `select`, both on the original unconverted operands. The new wide
// cmpi/select ops are expected to be legalized in turn by ConvertCmpI and
// ConvertSelect.
492template <
typename SourceOp, arith::CmpIPredicate CmpPred>
493struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
494 using OpConversionPattern<SourceOp>::OpConversionPattern;
497 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
498 ConversionPatternRewriter &rewriter)
const override {
499 Location loc = op->getLoc();
501 Type oldTy = op.getType();
// Only fire when the type is actually emulated (converts to a vector).
502 auto newTy = dyn_cast_or_null<VectorType>(
503 this->getTypeConverter()->convertType(oldTy));
505 return rewriter.notifyMatchFailure(
506 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
511 arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs());
512 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
// Checks whether `type` is `index` or a vector whose element type is
// `index`. Used to decide which direction an index cast goes.
522static bool isIndexOrIndexVector(
Type type) {
523 if (isa<IndexType>(type))
526 if (
auto vectorTy = dyn_cast<VectorType>(type))
527 if (isa<IndexType>(vectorTy.getElementType()))
// Lowers `index_cast`/`index_castui` from an emulated wide integer to
// `index`. Applies only when the result is `index` (or a vector of index).
// The cast is re-created on `extracted`; the code computing it is not shown
// here, but it is presumably the low half of the converted input. TODO
// confirm.
533template <
typename CastOp>
534struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
535 using OpConversionPattern<CastOp>::OpConversionPattern;
538 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const override {
540 Type resultType = op.getType();
// Casts in the other direction are handled by ConvertIndexCastIndexToInt.
541 if (!isIndexOrIndexVector(resultType))
544 Location loc = op.getLoc();
545 Type inType = op.getIn().getType();
547 this->getTypeConverter()->template convertType<VectorType>(inType);
549 return rewriter.notifyMatchFailure(
550 loc, llvm::formatv(
"unsupported type: {0}", inType));
555 rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
// Lowers `index_cast`/`index_castui` from `index` to an emulated wide
// integer. First the index is cast to the widest natively supported integer
// type (vector-shaped if the result is a vector). Then `ExtensionOp` widens
// that to the original result type: ExtSIOp for index_cast, ExtUIOp for
// index_castui, per the pattern registration. The extension is itself
// expected to be legalized by ConvertExtSI/ConvertExtUI.
560template <
typename CastOp,
typename ExtensionOp>
561struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
562 using OpConversionPattern<CastOp>::OpConversionPattern;
565 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
566 ConversionPatternRewriter &rewriter)
const override {
567 Type inType = op.getIn().getType();
// Casts in the other direction are handled by ConvertIndexCastIntToIndex.
568 if (!isIndexOrIndexVector(inType))
571 Location loc = op.getLoc();
// Needs the concrete converter to query the max native integer width.
572 auto *typeConverter =
573 this->
template getTypeConverter<arith::WideIntEmulationConverter>();
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType<VectorType>(resultType);
578 return rewriter.notifyMatchFailure(
579 loc, llvm::formatv(
"unsupported type: {0}", resultType));
583 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (
auto vecTy = dyn_cast<VectorType>(resultType))
585 narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
589 Value underlyingVal =
590 CastOp::create(rewriter, loc, narrowTy, adaptor.getIn());
591 rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
// Lowers `arith.select` over emulated wide integers into one select per
// half. The condition gets a trailing x1 dimension (`appendX1Dim`) so its
// shape lines up with the extra dimension of the converted values.
600struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
605 ConversionPatternRewriter &rewriter)
const override {
606 Location loc = op->getLoc();
607 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
609 return rewriter.notifyMatchFailure(
610 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
612 auto [trueElem0, trueElem1] =
614 auto [falseElem0, falseElem1] =
616 Value cond =
appendX1Dim(rewriter, loc, adaptor.getCondition());
619 arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0);
621 arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1);
624 rewriter.replaceOp(op, resultVec);
// Lowers `arith.shli` on emulated wide integers. Only the low half of the
// shift amount (`rhsElem0`) is used. Let BW = half width and s = rhsElem0:
//   s <  BW: low  = lhsLow << s
//            high = (lhsHigh << s) | (lhsLow >> (BW - s))
//   s >= BW: low  = 0
//            high = lhsLow << (s - BW)
// The s >= BW case is tracked by `illegalElemShift`. Both variants are
// computed and chosen with selects.
633struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
638 ConversionPatternRewriter &rewriter)
const override {
639 Location loc = op->getLoc();
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
644 return rewriter.notifyMatchFailure(
645 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
651 auto [lhsElem0, lhsElem1] =
// True when the shift moves the whole low half into (or past) the high half.
683 Value illegalElemShift = arith::CmpIOp::create(
684 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
// Low half: lhsLow << s, or zero for large shifts.
687 arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0);
688 Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
689 zeroCst, shiftedElem0);
// Bits spilling from the low half into the high half: lhsLow >> (BW - s).
// The amount is capped at BW so the subtraction cannot underflow.
691 Value cappedShiftAmount = arith::SelectOp::create(
692 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
696 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount);
// Large-shift case: lhsLow << (s - BW) becomes the whole high half.
697 Value overshotShiftAmount =
698 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
700 arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount);
// High half's own bits: lhsHigh << s, or zero for large shifts.
703 arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0);
704 Value resElem1High = arith::SelectOp::create(
705 rewriter, loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
707 shiftedLeft, shiftedRight);
709 arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High);
713 rewriter.replaceOp(op, resultVec);
// Lowers `arith.shrui` on emulated wide integers; this mirrors ConvertShLI.
// Let BW = half width and s = rhsElem0:
//   s <  BW: high = lhsHigh >> s
//            low  = (lhsLow >> s) | (lhsHigh << (BW - s))
//   s >= BW: high = 0
//            low  = lhsHigh >> (s - BW)
// Both variants are computed and chosen with selects on `illegalElemShift`.
722struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter)
const override {
728 Location loc = op->getLoc();
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
733 return rewriter.notifyMatchFailure(
734 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
740 auto [lhsElem0, lhsElem1] =
// True when the shift moves the whole high half into (or past) the low half.
772 Value illegalElemShift = arith::CmpIOp::create(
773 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
// Low half's own bits: lhsLow >> s, or zero for large shifts.
776 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0);
777 Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
778 zeroCst, shiftedElem0);
// High half: lhsHigh >> s, or zero for large shifts.
780 arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0);
781 Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
782 zeroCst, shiftedElem1);
// Bits moving from the high half into the low half: lhsHigh << (BW - s).
// The amount is capped at BW so the subtraction cannot underflow.
784 Value cappedShiftAmount = arith::SelectOp::create(
785 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
789 arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount);
// Large-shift case: lhsHigh >> (s - BW) becomes the whole low half.
790 Value overshotShiftAmount =
791 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
793 arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount);
795 Value resElem0High = arith::SelectOp::create(
796 rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft);
798 arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High);
802 rewriter.replaceOp(op, resultVec);
// Lowers `arith.shrsi` on emulated wide integers in terms of other wide ops,
// which are in turn legalized by their own patterns:
//   shrsi(x, s) = shrui(x, s) | (allSign << (maxShift - s))
// Here allSign is all ones when the high half of x is negative and zero
// otherwise. A shift amount of zero selects x unchanged. Presumably this
// avoids shifting `allSign` by the full width (maxShift's definition is not
// shown here). TODO confirm.
811struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
816 ConversionPatternRewriter &rewriter)
const override {
817 Location loc = op->getLoc();
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
822 return rewriter.notifyMatchFailure(
823 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
828 Type narrowTy = rhsElem0.
getType();
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
// The wide value is negative iff its high half is negative.
835 Value signBit = arith::CmpIOp::create(
836 rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
842 Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit);
// Number of low bits that must not be filled with the sign.
845 Value numNonSignExtBits =
846 arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0);
849 arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits);
851 arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits);
// Logical shift of the original operands, then OR in the sign fill.
855 arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs());
856 Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits);
860 Value isNoop = arith::CmpIOp::create(
861 rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero);
863 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
// Lowers `arith.subi` on emulated wide integers using a borrow:
//   low    = lhsLow - rhsLow
//   borrow = lhsLow <u rhsLow   (zero-extended to the half type)
//   high   = lhsHigh - borrow - rhsHigh
// The diagnostic uses an explicit `{0}` placeholder, consistent with the
// other patterns in this file.
874struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
879 ConversionPatternRewriter &rewriter)
const override {
880 Location loc = op->getLoc();
881 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
883 return rewriter.notifyMatchFailure(
884 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
888 auto [lhsElem0, lhsElem1] =
890 auto [rhsElem0, rhsElem1] =
895 Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0);
// The low-half subtraction wrapped iff lhsLow < rhsLow (unsigned): borrow 1.
897 Value carry0 = arith::CmpIOp::create(
898 rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0);
901 Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal);
902 Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1);
905 rewriter.replaceOp(op, resultVec);
// Lowers `arith.sitofp` from an emulated wide integer via its absolute
// value: uitofp(|x|), negated when x < 0. The integer ops are built on the
// original wide value `in` and are legalized by the other patterns. For the
// minimum signed value, `0 - in` wraps back to `in`, whose unsigned value is
// 2^(w-1), so negating the float still yields the correct result.
914struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter)
const override {
920 Location loc = op.getLoc();
922 Value in = op.getIn();
924 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
926 return rewriter.notifyMatchFailure(
927 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
936 Value isNeg = arith::CmpIOp::create(rewriter, loc,
937 arith::CmpIPredicate::slt, in, zeroCst);
938 Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in);
939 Value
abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in);
941 Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs);
942 Value negResult = arith::NegFOp::create(rewriter, loc, absResult);
943 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
// Lowers `arith.uitofp` from an emulated wide integer. Each half
// (`lowInt`, `hiInt`) is converted separately and the results are combined:
//   hi == 0 ? uitofp(lo) : uitofp(lo) + uitofp(hi) * 2^newBitWidth
// Taking uitofp(lo) directly when hi == 0 avoids computing 0 * 2^BW, which
// is NaN when 2^BW is not finite in the result float type.
953struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
958 ConversionPatternRewriter &rewriter)
const override {
959 Location loc = op.getLoc();
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
964 return rewriter.notifyMatchFailure(
965 loc, llvm::formatv(
"unsupported type: {0}", oldTy));
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
988 Value hiEqZero = arith::CmpIOp::create(
989 rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
991 Type resultTy = op.getType();
993 Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt);
994 Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt);
// 2^newBitWidth is built in a 2*newBitWidth-bit APInt, as ConvertFPToUI
// does. A plain `int64_t(1) << newBitWidth` is undefined behavior for
// newBitWidth == 64 (e.g. i128 emulated with i64 halves). Powers of two up
// to 2^1023 are exactly representable as a double.
996 double pow2 = APInt(newBitWidth * 2, 1).shl(newBitWidth).roundToDouble();
998 rewriter.getFloatAttr(resultElemTy,
pow2);
999 if (
auto vecTy = dyn_cast<VectorType>(resultTy))
1000 pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
1003 arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr);
1005 Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val);
1006 Value
result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal);
1008 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp,
result);
// Lowers `arith.fptosi` to an emulated wide integer via the absolute value:
//   res = fptoui(|x|);  result = x < 0 ? 0 - res : res
// The wide fptoui/subi are legalized by their own patterns. `OLT` is an
// ordered compare, so a NaN input is treated as non-negative. The
// diagnostic uses `{0}`, consistent with the other patterns in this file.
1017struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
1021 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter)
const override {
1023 Location loc = op.getLoc();
1025 Value inFp = adaptor.getIn();
1028 Type intTy = op.getType();
1030 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1032 return rewriter.notifyMatchFailure(
1033 loc, llvm::formatv(
"unsupported type: {0}", intTy));
1040 TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
1041 Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr);
1046 Value isNeg = arith::CmpFOp::create(
1047 rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst);
1048 Value negInFp = arith::NegFOp::create(rewriter, loc, inFp);
1050 Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp);
1053 Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal);
1056 Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res);
1058 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
// Lowers `arith.fptoui` to an emulated wide integer by splitting the float
// around 2^BW (BW = half width):
//   high = fptoui(x / 2^BW),  low = fptoui(frem(x, 2^BW))
// The diagnostic uses `{0}`, consistent with the other patterns in this file.
1067struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
1071 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter)
const override {
1073 Location loc = op.getLoc();
1075 Value inFp = adaptor.getIn();
1078 Type intTy = op.getType();
1079 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1081 return rewriter.notifyMatchFailure(
1082 loc, llvm::formatv(
"unsupported type: {0}", intTy));
1083 unsigned newBitWidth = newTy.getElementTypeBitWidth();
// Integer type of one half, vector-shaped if the float input is a vector.
1085 Type newHalfType = IntegerType::get(inFp.
getContext(), newBitWidth);
1086 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1087 newHalfType = VectorType::get(vecType.getShape(), newHalfType);
1095 const llvm::fltSemantics &fSemantics =
1098 auto powBitwidth = llvm::APFloat(fSemantics);
// 2^BW is built in a 2*BW-bit APInt so the shift cannot overflow. If the
// float type cannot hold it exactly (opInexact), +inf is used instead.
// Then x / inf == 0 and frem(x, inf) == x, so the whole value lands in the
// low half.
1103 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
1104 false, llvm::RoundingMode::TowardZero) ==
1105 llvm::detail::opStatus::opInexact)
1106 powBitwidth = llvm::APFloat::getInf(fSemantics);
1108 TypedAttr powBitwidthAttr =
1110 if (
auto vecType = dyn_cast<VectorType>(fpTy))
1111 powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
1112 Value powBitwidthFloatCst =
1113 arith::ConstantOp::create(rewriter, loc, powBitwidthAttr);
// High half: integer part of x / 2^BW.
1115 Value fpDivPowBitwidth =
1116 arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1118 arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth);
// Low half: x modulo 2^BW.
1121 arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1123 arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder);
1130 rewriter.replaceOp(op, resultVec);
// Lowers `arith.trunci` from an emulated wide integer. Applies only when the
// truncation result type is already legal for the type converter. The
// result is a (folded, when possible) `arith.trunci` of `extracted`; the code
// computing it is not shown here, but it is presumably the low half. TODO
// confirm.
1139struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
1143 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1144 ConversionPatternRewriter &rewriter)
const override {
1145 Location loc = op.getLoc();
1148 if (!getTypeConverter()->isLegal(op.getType()))
1149 return rewriter.notifyMatchFailure(
1150 loc, llvm::formatv(
"unsupported truncation result type: {0}",
1158 rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1159 rewriter.replaceOp(op, truncated);
// Re-creates `vector.print` on the type-converted source. A wide integer is
// therefore printed in its emulated (split-halves vector) form.
1168struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1172 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1173 ConversionPatternRewriter &rewriter)
const override {
1174 rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
// Pass driver: emulates integers wider than `widestIntSupported` with
// vectors of narrower halves, using partial dialect conversion.
1183struct EmulateWideIntPass final
1185 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1187 void runOnOperation()
override {
// Validate the option up front, mirroring the WideIntEmulationConverter
// constructor's asserts (power of two, >= 2). An invalid option then fails
// the pass instead of tripping an assert.
1188 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1189 signalPassFailure();
1193 Operation *op = getOperation();
1196 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1197 ConversionTarget
target(*ctx);
// A func.func is legal once its signature contains no wide integers.
1198 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1199 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
// Other ops are legal when all their operand/result types are legal.
1201 auto opLegalCallback = [&typeConverter](Operation *op) {
1202 return typeConverter.isLegal(op);
1204 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1205 target.addDynamicallyLegalOp<vector::PrintOp>(opLegalCallback);
1206 target.addDynamicallyLegalDialect<arith::ArithDialect>(opLegalCallback);
// NOTE(review): vector.print's op-specific rule above presumably takes
// precedence over this dialect-wide legality; confirm.
1207 target.addLegalDialect<vector::VectorDialect>();
1210 arith::populateArithWideIntEmulationPatterns(typeConverter,
patterns);
1213 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1219 signalPassFailure();
// Type converter for wide-integer emulation. Integers of exactly twice the
// maximum supported width become a trailing dimension of 2 halves of the
// supported width. Scalars map to vector<2xiN>; vectors gain an extra
// trailing dim of 2. Function types are converted through their inputs and
// results. The assert message now reads correctly (was "Only power-of-two
// integers with are supported").
1229 unsigned widestIntSupportedByTarget)
1230 : maxIntWidth(widestIntSupportedByTarget) {
1231 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1232 "Only power-of-two integer widths are supported");
1233 assert(widestIntSupportedByTarget >= 2 &&
"Integer type too narrow");
// Catch-all identity conversion for types not handled below.
1236 addConversion([](
Type ty) -> std::optional<Type> {
return ty; });
// Scalar integers: natively supported widths are kept; 2x-wide integers
// are split into a vector of two halves.
1239 addConversion([
this](IntegerType ty) -> std::optional<Type> {
1240 unsigned width = ty.getWidth();
1241 if (width <= maxIntWidth)
1245 if (width == 2 * maxIntWidth)
1246 return VectorType::get(2, IntegerType::get(ty.
getContext(), maxIntWidth));
// Integer vectors: 2x-wide elements get an extra trailing dimension of 2.
1252 addConversion([
this](VectorType ty) -> std::optional<Type> {
1253 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1257 unsigned width = intTy.getWidth();
1258 if (width <= maxIntWidth)
1262 if (width == 2 * maxIntWidth) {
1263 auto newShape = to_vector(ty.getShape());
1264 newShape.push_back(2);
1265 return VectorType::get(newShape,
1266 IntegerType::get(ty.
getContext(), maxIntWidth));
// Function types: convert every input and result type.
1273 addConversion([
this](FunctionType ty) -> std::optional<Type> {
1277 if (failed(convertTypes(ty.getInputs(), inputs)))
1281 if (failed(convertTypes(ty.getResults(), results)))
1284 return FunctionType::get(ty.
getContext(), inputs, results);
1294 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1296 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1297 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1298 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1299 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1300 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1302 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1303 ConvertBitwiseBinary<arith::XOrIOp>,
1305 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1307 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1308 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1310 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1311 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1312 typeConverter,
patterns.getContext());