31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/Support/Casting.h"
43 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
46 auto constant = rewriter.
create<LLVM::ConstantOp>(
49 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
52 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
58 Value val,
Type llvmType, int64_t rank, int64_t pos) {
61 auto constant = rewriter.
create<LLVM::ConstantOp>(
64 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
67 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
72 VectorType vectorType,
unsigned &align) {
74 if (!convertedVectorTy)
77 llvm::LLVMContext llvmContext;
87 MemRefType memrefType,
unsigned &align) {
88 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
94 llvm::LLVMContext llvmContext;
107 VectorType vectorType,
108 MemRefType memrefType,
unsigned &align,
109 bool useVectorAlignment) {
110 if (useVectorAlignment) {
125 if (!memRefType.isLastDimUnitStride())
135 MemRefType memRefType,
Value llvmMemref,
Value base,
136 Value index, VectorType vectorType) {
138 "unsupported memref type");
139 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
143 vectorType.getScalableDims()[0]);
144 return rewriter.
create<LLVM::GEPOp>(
145 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
153 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
154 auto intAttr = cast<IntegerAttr>(attr);
155 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
158 return cast<Value>(foldResult);
164 using VectorScaleOpConversion =
168 class VectorBitCastOpConversion
174 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 VectorType resultTy = bitCastOp.getResultVectorType();
178 if (resultTy.getRank() > 1)
180 Type newResultTy = typeConverter->convertType(resultTy);
182 adaptor.getOperands()[0]);
189 class VectorMatmulOpConversion
195 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
198 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
199 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
200 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
207 class VectorFlatTransposeOpConversion
213 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
216 transOp, typeConverter->convertType(transOp.getRes().getType()),
217 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
225 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
226 vector::LoadOpAdaptor adaptor,
227 VectorType vectorTy,
Value ptr,
unsigned align,
231 loadOp.getNontemporal());
234 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
235 vector::MaskedLoadOpAdaptor adaptor,
236 VectorType vectorTy,
Value ptr,
unsigned align,
239 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
242 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
243 vector::StoreOpAdaptor adaptor,
244 VectorType vectorTy,
Value ptr,
unsigned align,
248 storeOp.getNontemporal());
251 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
252 vector::MaskedStoreOpAdaptor adaptor,
253 VectorType vectorTy,
Value ptr,
unsigned align,
256 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
261 template <
class LoadOrStoreOp>
267 useVectorAlignment(useVectorAlign) {}
271 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
272 typename LoadOrStoreOp::Adaptor adaptor,
275 VectorType vectorTy = loadOrStoreOp.getVectorType();
276 if (vectorTy.getRank() > 1)
279 auto loc = loadOrStoreOp->getLoc();
280 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
285 memRefTy, align, useVectorAlignment)))
287 "could not resolve alignment");
290 auto vtype = cast<VectorType>(
291 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
293 adaptor.getIndices(), rewriter);
294 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
304 const bool useVectorAlignment;
308 class VectorGatherOpConversion
314 useVectorAlignment(useVectorAlign) {}
318 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
321 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
322 assert(memRefType &&
"The base should be bufferized");
327 VectorType vType = gather.getVectorType();
328 if (vType.getRank() > 1) {
330 gather,
"only 1-D vectors can be lowered to LLVM");
336 memRefType, align, useVectorAlignment)))
340 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
341 adaptor.getIndices(), rewriter);
342 Value base = adaptor.getBase();
344 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
345 base, ptr, adaptor.getIndexVec(), vType);
349 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
359 const bool useVectorAlignment;
363 class VectorScatterOpConversion
369 useVectorAlignment(useVectorAlign) {}
374 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
376 auto loc = scatter->getLoc();
377 MemRefType memRefType = scatter.getMemRefType();
382 VectorType vType = scatter.getVectorType();
383 if (vType.getRank() > 1) {
385 scatter,
"only 1-D vectors can be lowered to LLVM");
391 memRefType, align, useVectorAlignment)))
393 "could not resolve alignment");
396 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397 adaptor.getIndices(), rewriter);
399 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
400 adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
404 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
414 const bool useVectorAlignment;
418 class VectorExpandLoadOpConversion
424 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
426 auto loc = expand->getLoc();
427 MemRefType memRefType = expand.getMemRefType();
430 auto vtype = typeConverter->convertType(expand.getVectorType());
431 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
432 adaptor.getIndices(), rewriter);
435 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
441 class VectorCompressStoreOpConversion
447 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
449 auto loc = compress->getLoc();
450 MemRefType memRefType = compress.getMemRefType();
453 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
454 adaptor.getIndices(), rewriter);
457 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
// Empty tag types naming the neutral (start) value of a reduction. Each tag is
// the first parameter of one createReductionNeutralValue(...) overload, so the
// tag alone picks the overload, and the tags are also supplied as the
// `ReductionNeutral` template argument of the reduction lowering helpers.
// The per-tag notes below list the use sites visible in this file.

// Used for ADD (integer and FP), OR and XOR reductions.
463 class ReductionNeutralZero {};
// Used for the integer flavour of masked MUL reductions.
464 class ReductionNeutralIntOne {};
// Used for FP MUL reductions (masked and unmasked).
465 class ReductionNeutralFPOne {};
// Used for masked AND reductions.
466 class ReductionNeutralAllOnes {};
// Used for masked MAXSI reductions.
467 class ReductionNeutralSIntMin {};
// Used for masked MAXUI reductions.
468 class ReductionNeutralUIntMin {};
// Used for masked MINSI reductions.
469 class ReductionNeutralSIntMax {};
// Used for masked unsigned-min (VPReduceUMinOp) reductions.
470 class ReductionNeutralUIntMax {};
// Used for masked MAXNUMF reductions; the matching overload builds its
// constant from llvm::APFloat::getQNaN.
471 class ReductionNeutralFPMin {};
// Used for masked MINNUMF reductions; the matching overload builds its
// constant from llvm::APFloat::getQNaN.
472 class ReductionNeutralFPMax {};
475 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
478 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
483 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
486 return rewriter.
create<LLVM::ConstantOp>(
491 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
494 return rewriter.
create<LLVM::ConstantOp>(
499 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
502 return rewriter.
create<LLVM::ConstantOp>(
509 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
512 return rewriter.
create<LLVM::ConstantOp>(
519 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
522 return rewriter.
create<LLVM::ConstantOp>(
529 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
532 return rewriter.
create<LLVM::ConstantOp>(
539 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
542 return rewriter.
create<LLVM::ConstantOp>(
549 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
552 auto floatType = cast<FloatType>(llvmType);
553 return rewriter.
create<LLVM::ConstantOp>(
556 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
561 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
564 auto floatType = cast<FloatType>(llvmType);
565 return rewriter.
create<LLVM::ConstantOp>(
568 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
574 template <
class ReductionNeutral>
581 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
590 VectorType vType = cast<VectorType>(llvmType);
591 auto vShape = vType.getShape();
592 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
594 Value baseVecLength = rewriter.
create<LLVM::ConstantOp>(
598 if (!vType.getScalableDims()[0])
599 return baseVecLength;
602 Value vScale = rewriter.
create<vector::VectorScaleOp>(loc);
605 Value scalableVecLength =
606 rewriter.
create<arith::MulIOp>(loc, baseVecLength, vScale);
607 return scalableVecLength;
614 template <
class LLVMRedIntrinOp,
class ScalarOp>
615 static Value createIntegerReductionArithmeticOpLowering(
619 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
622 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
630 template <
class LLVMRedIntrinOp>
631 static Value createIntegerReductionComparisonOpLowering(
633 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
634 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
637 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
638 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
// Primary template, declared but intentionally left undefined here: only the
// specializations that follow provide a nested `Type`. Each specialization
// pairs an LLVM vector reduction intrinsic op (`Source`) with the scalar op
// named by `Type`, which createFPReductionComparisonOpLowering creates from
// the reduction result and the accumulator.
644 template <
typename Source>
645 struct VectorToScalarMapper;
647 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
648 using Type = LLVM::MaximumOp;
651 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
652 using Type = LLVM::MinimumOp;
655 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
656 using Type = LLVM::MaxNumOp;
659 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
660 using Type = LLVM::MinNumOp;
664 template <
class LLVMRedIntrinOp>
665 static Value createFPReductionComparisonOpLowering(
667 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
669 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
673 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
674 loc, result, accumulator);
// Empty tag types that select a getMaskNeutralValue(...) overload and serve as
// the `MaskNeutral` template argument of the masked FP reduction helpers.

// Paired with LLVM::vector_reduce_fmaximum for masked MAXIMUMF reductions; its
// overload returns llvm::APFloat::getSmallest(floatSemantics, true).
// NOTE(review): confirm this is the intended neutral element for a maximum.
681 class MaskNeutralFMaximum {};
// Paired with LLVM::vector_reduce_fminimum for masked MINIMUMF reductions; its
// overload returns llvm::APFloat::getLargest(floatSemantics, false).
682 class MaskNeutralFMinimum {};
686 getMaskNeutralValue(MaskNeutralFMaximum,
687 const llvm::fltSemantics &floatSemantics) {
688 return llvm::APFloat::getSmallest(floatSemantics,
true);
692 getMaskNeutralValue(MaskNeutralFMinimum,
693 const llvm::fltSemantics &floatSemantics) {
694 return llvm::APFloat::getLargest(floatSemantics,
false);
698 template <
typename MaskNeutral>
702 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
703 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
705 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
712 template <
class LLVMRedIntrinOp,
class MaskNeutral>
717 Value mask, LLVM::FastmathFlagsAttr fmf) {
718 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
719 rewriter, loc, llvmType, vectorOperand.
getType());
720 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
721 loc, mask, vectorOperand, vectorMaskNeutral);
722 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
723 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
726 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
730 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
731 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
732 llvmType, accumulator);
733 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
741 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
746 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
747 llvmType, accumulator);
748 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
753 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
754 static Value lowerPredicatedReductionWithStartValue(
757 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
758 llvmType, accumulator);
760 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
761 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
763 vectorOperand, mask, vectorLength);
766 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
767 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
768 static Value lowerPredicatedReductionWithStartValue(
772 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
773 IntReductionNeutral>(
774 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
777 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
779 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
783 class VectorReductionOpConversion
787 bool reassociateFPRed)
789 reassociateFPReductions(reassociateFPRed) {}
792 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
794 auto kind = reductionOp.getKind();
795 Type eltType = reductionOp.getDest().getType();
796 Type llvmType = typeConverter->convertType(eltType);
797 Value operand = adaptor.getVector();
798 Value acc = adaptor.getAcc();
799 Location loc = reductionOp.getLoc();
805 case vector::CombiningKind::ADD:
807 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
809 rewriter, loc, llvmType, operand, acc);
811 case vector::CombiningKind::MUL:
813 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
815 rewriter, loc, llvmType, operand, acc);
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::ule);
822 case vector::CombiningKind::MINSI:
823 result = createIntegerReductionComparisonOpLowering<
824 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
825 LLVM::ICmpPredicate::sle);
827 case vector::CombiningKind::MAXUI:
828 result = createIntegerReductionComparisonOpLowering<
829 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
830 LLVM::ICmpPredicate::uge);
832 case vector::CombiningKind::MAXSI:
833 result = createIntegerReductionComparisonOpLowering<
834 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
835 LLVM::ICmpPredicate::sge);
837 case vector::CombiningKind::AND:
839 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
841 rewriter, loc, llvmType, operand, acc);
843 case vector::CombiningKind::OR:
845 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
847 rewriter, loc, llvmType, operand, acc);
849 case vector::CombiningKind::XOR:
851 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
853 rewriter, loc, llvmType, operand, acc);
863 if (!isa<FloatType>(eltType))
866 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
868 reductionOp.getContext(),
871 reductionOp.getContext(),
872 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
873 : LLVM::FastmathFlags::none));
877 if (
kind == vector::CombiningKind::ADD) {
878 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
879 ReductionNeutralZero>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 }
else if (
kind == vector::CombiningKind::MUL) {
882 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
883 ReductionNeutralFPOne>(
884 rewriter, loc, llvmType, operand, acc, fmf);
885 }
else if (
kind == vector::CombiningKind::MINIMUMF) {
887 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
888 rewriter, loc, llvmType, operand, acc, fmf);
889 }
else if (
kind == vector::CombiningKind::MAXIMUMF) {
891 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
892 rewriter, loc, llvmType, operand, acc, fmf);
893 }
else if (
kind == vector::CombiningKind::MINNUMF) {
894 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
895 rewriter, loc, llvmType, operand, acc, fmf);
896 }
else if (
kind == vector::CombiningKind::MAXNUMF) {
897 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
898 rewriter, loc, llvmType, operand, acc, fmf);
908 const bool reassociateFPReductions;
919 template <
class MaskedOp>
920 class VectorMaskOpConversionBase
926 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
929 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
932 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
936 virtual LogicalResult
937 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
938 vector::MaskableOpInterface maskableOp,
942 class MaskedReductionOpConversion
943 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
946 using VectorMaskOpConversionBase<
947 vector::ReductionOp>::VectorMaskOpConversionBase;
949 LogicalResult matchAndRewriteMaskableOp(
950 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
952 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
953 auto kind = reductionOp.getKind();
954 Type eltType = reductionOp.getDest().getType();
955 Type llvmType = typeConverter->convertType(eltType);
956 Value operand = reductionOp.getVector();
957 Value acc = reductionOp.getAcc();
958 Location loc = reductionOp.getLoc();
960 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
962 reductionOp.getContext(),
967 case vector::CombiningKind::ADD:
968 result = lowerPredicatedReductionWithStartValue<
969 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
970 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
973 case vector::CombiningKind::MUL:
974 result = lowerPredicatedReductionWithStartValue<
975 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
976 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
980 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
981 ReductionNeutralUIntMax>(
982 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
984 case vector::CombiningKind::MINSI:
985 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
986 ReductionNeutralSIntMax>(
987 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
989 case vector::CombiningKind::MAXUI:
990 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
991 ReductionNeutralUIntMin>(
992 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
994 case vector::CombiningKind::MAXSI:
995 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
996 ReductionNeutralSIntMin>(
997 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
999 case vector::CombiningKind::AND:
1000 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
1001 ReductionNeutralAllOnes>(
1002 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1004 case vector::CombiningKind::OR:
1005 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
1006 ReductionNeutralZero>(
1007 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1009 case vector::CombiningKind::XOR:
1010 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
1011 ReductionNeutralZero>(
1012 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1014 case vector::CombiningKind::MINNUMF:
1015 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1016 ReductionNeutralFPMax>(
1017 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1019 case vector::CombiningKind::MAXNUMF:
1020 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1021 ReductionNeutralFPMin>(
1022 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1024 case CombiningKind::MAXIMUMF:
1025 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1026 MaskNeutralFMaximum>(
1027 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1029 case CombiningKind::MINIMUMF:
1030 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1031 MaskNeutralFMinimum>(
1032 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1037 rewriter.replaceOp(maskOp, result);
1042 class VectorShuffleOpConversion
1048 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1050 auto loc = shuffleOp->getLoc();
1051 auto v1Type = shuffleOp.getV1VectorType();
1052 auto v2Type = shuffleOp.getV2VectorType();
1053 auto vectorType = shuffleOp.getResultVectorType();
1054 Type llvmType = typeConverter->convertType(vectorType);
1062 int64_t rank = vectorType.getRank();
1064 bool wellFormed0DCase =
1065 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1066 bool wellFormedNDCase =
1067 v1Type.getRank() == rank && v2Type.getRank() == rank;
1068 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1073 if (rank <= 1 && v1Type == v2Type) {
1074 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1075 loc, adaptor.getV1(), adaptor.getV2(),
1076 llvm::to_vector_of<int32_t>(mask));
1077 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1082 int64_t v1Dim = v1Type.getDimSize(0);
1084 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1085 eltType = arrayType.getElementType();
1087 eltType = cast<VectorType>(llvmType).getElementType();
1088 Value insert = rewriter.
create<LLVM::PoisonOp>(loc, llvmType);
1090 for (int64_t extPos : mask) {
1091 Value value = adaptor.getV1();
1092 if (extPos >= v1Dim) {
1094 value = adaptor.getV2();
1097 eltType, rank, extPos);
1098 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1099 llvmType, rank, insPos++);
1106 class VectorExtractElementOpConversion
1113 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1115 auto vectorType = extractEltOp.getSourceVectorType();
1116 auto llvmType = typeConverter->convertType(vectorType.getElementType());
1122 if (vectorType.getRank() == 0) {
1123 Location loc = extractEltOp.getLoc();
1125 auto zero = rewriter.
create<LLVM::ConstantOp>(
1126 loc, typeConverter->convertType(idxType),
1129 extractEltOp, llvmType, adaptor.getVector(), zero);
1134 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1139 class VectorExtractOpConversion
1145 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1147 auto loc = extractOp->getLoc();
1148 auto resultType = extractOp.getResult().getType();
1149 auto llvmResultType = typeConverter->convertType(resultType);
1151 if (!llvmResultType)
1155 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1169 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1173 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1174 extractOp.getSourceVectorType().getRank();
1178 if (extractOp.getSourceVectorType().getRank() == 0) {
1180 positionVec.push_back(rewriter.
getZeroAttr(idxType));
1183 Value extracted = adaptor.getVector();
1184 if (extractsAggregate) {
1186 if (extractsScalar) {
1190 position = position.drop_back();
1193 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1196 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1200 if (extractsScalar) {
1201 extracted = rewriter.
create<LLVM::ExtractElementOp>(
1202 loc, extracted,
getAsLLVMValue(rewriter, loc, positionVec.back()));
1205 rewriter.
replaceOp(extractOp, extracted);
1229 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1231 VectorType vType = fmaOp.getVectorType();
1232 if (vType.getRank() > 1)
1236 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1241 class VectorInsertElementOpConversion
1247 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1249 auto vectorType = insertEltOp.getDestVectorType();
1250 auto llvmType = typeConverter->convertType(vectorType);
1256 if (vectorType.getRank() == 0) {
1257 Location loc = insertEltOp.getLoc();
1259 auto zero = rewriter.
create<LLVM::ConstantOp>(
1260 loc, typeConverter->convertType(idxType),
1263 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1268 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1269 adaptor.getPosition());
1274 class VectorInsertOpConversion
1280 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1282 auto loc = insertOp->getLoc();
1283 auto destVectorType = insertOp.getDestVectorType();
1284 auto llvmResultType = typeConverter->convertType(destVectorType);
1286 if (!llvmResultType)
1290 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1312 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1314 bool insertIntoInnermostDim =
1315 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
1318 positionVec.begin(),
1319 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1321 if (destVectorType.getRank() == 0) {
1325 positionOfScalarWithin1DVector = rewriter.
getZeroAttr(idxType);
1326 }
else if (insertIntoInnermostDim) {
1327 positionOfScalarWithin1DVector = positionVec.back();
1333 Value sourceAggregate = adaptor.getValueToStore();
1334 if (insertIntoInnermostDim) {
1337 if (isNestedAggregate) {
1340 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1341 llvm::IsaPred<Attribute>)) {
1345 sourceAggregate = rewriter.
create<LLVM::ExtractValueOp>(
1346 loc, adaptor.getDest(),
1351 sourceAggregate = adaptor.getDest();
1354 sourceAggregate = rewriter.
create<LLVM::InsertElementOp>(
1355 loc, sourceAggregate.
getType(), sourceAggregate,
1356 adaptor.getValueToStore(),
1360 Value result = sourceAggregate;
1361 if (isNestedAggregate) {
1362 result = rewriter.
create<LLVM::InsertValueOp>(
1363 loc, adaptor.getDest(), sourceAggregate,
1373 struct VectorScalableInsertOpLowering
1379 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1382 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1388 struct VectorScalableExtractOpLowering
1394 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1397 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1398 adaptor.getSource(), adaptor.getPos());
1431 setHasBoundedRewriteRecursion();
1434 LogicalResult matchAndRewrite(FMAOp op,
1436 auto vType = op.getVectorType();
1437 if (vType.getRank() < 2)
1440 auto loc = op.getLoc();
1441 auto elemType = vType.getElementType();
1444 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1445 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1446 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1447 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1448 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1449 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1450 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
1459 static std::optional<SmallVector<int64_t, 4>>
1460 computeContiguousStrides(MemRefType memRefType) {
1463 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1464 return std::nullopt;
1465 if (!strides.empty() && strides.back() != 1)
1466 return std::nullopt;
1468 if (memRefType.getLayout().isIdentity())
1475 auto sizes = memRefType.getShape();
1476 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1477 if (ShapedType::isDynamic(sizes[index + 1]) ||
1478 ShapedType::isDynamic(strides[index]) ||
1479 ShapedType::isDynamic(strides[index + 1]))
1480 return std::nullopt;
1481 if (strides[index] != strides[index + 1] * sizes[index + 1])
1482 return std::nullopt;
1487 class VectorTypeCastOpConversion
1493 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1495 auto loc = castOp->getLoc();
1496 MemRefType sourceMemRefType =
1497 cast<MemRefType>(castOp.getOperand().getType());
1498 MemRefType targetMemRefType = castOp.getType();
1501 if (!sourceMemRefType.hasStaticShape() ||
1502 !targetMemRefType.hasStaticShape())
1505 auto llvmSourceDescriptorTy =
1506 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1507 if (!llvmSourceDescriptorTy)
1511 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1512 typeConverter->convertType(targetMemRefType));
1513 if (!llvmTargetDescriptorTy)
1517 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1520 auto targetStrides = computeContiguousStrides(targetMemRefType);
1524 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1532 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1533 desc.setAllocatedPtr(rewriter, loc, allocated);
1536 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1537 desc.setAlignedPtr(rewriter, loc, ptr);
1540 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1541 desc.setOffset(rewriter, loc, zero);
1544 for (
const auto &indexedSize :
1546 int64_t index = indexedSize.index();
1549 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1550 desc.setSize(rewriter, loc, index, size);
1552 (*targetStrides)[index]);
1553 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1554 desc.setStride(rewriter, loc, index, stride);
1564 class VectorCreateMaskOpConversion
1567 explicit VectorCreateMaskOpConversion(
MLIRContext *context,
1568 bool enableIndexOpt)
1570 force32BitVectorIndices(enableIndexOpt) {}
1573 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1575 auto dstType = op.getType();
1576 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1578 IntegerType idxType =
1580 auto loc = op->getLoc();
1581 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1585 adaptor.getOperands()[0]);
1587 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1594 const bool force32BitVectorIndices;
1615 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1617 auto parent =
printOp->getParentOfType<ModuleOp>();
1623 if (
auto value = adaptor.getSource()) {
1629 if (failed(emitScalarPrint(rewriter, parent, loc,
printType, value)))
1633 auto punct =
printOp.getPunctuation();
1634 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1637 *stringLiteral, *getTypeConverter(),
1639 if (createResult.failed())
1642 }
else if (punct != PrintPunctuation::NoPunctuation) {
1643 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1645 case PrintPunctuation::Close:
1647 case PrintPunctuation::Open:
1649 case PrintPunctuation::Comma:
1651 case PrintPunctuation::NewLine:
1654 llvm_unreachable(
"unexpected punctuation");
1659 emitCall(rewriter,
printOp->getLoc(), op.value());
1667 enum class PrintConversion {
1678 Value value)
const {
1679 if (typeConverter->convertType(
printType) ==
nullptr)
1684 FailureOr<Operation *> printer;
1690 conversion = PrintConversion::Bitcast16;
1693 conversion = PrintConversion::Bitcast16;
1697 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1701 unsigned width = intTy.getWidth();
1702 if (intTy.isUnsigned()) {
1705 conversion = PrintConversion::ZeroExt64;
1711 assert(intTy.isSignless() || intTy.isSigned());
1716 conversion = PrintConversion::ZeroExt64;
1717 else if (width < 64)
1718 conversion = PrintConversion::SignExt64;
1727 if (failed(printer))
1730 switch (conversion) {
1731 case PrintConversion::ZeroExt64:
1732 value = rewriter.
create<arith::ExtUIOp>(
1735 case PrintConversion::SignExt64:
1736 value = rewriter.
create<arith::ExtSIOp>(
1739 case PrintConversion::Bitcast16:
1740 value = rewriter.
create<LLVM::BitcastOp>(
1746 emitCall(rewriter, loc, printer.value(), value);
1764 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1766 VectorType resultType = cast<VectorType>(splatOp.getType());
1767 if (resultType.getRank() > 1)
1771 auto vectorType = typeConverter->convertType(splatOp.getType());
1773 rewriter.
create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
1774 auto zero = rewriter.
create<LLVM::ConstantOp>(
1780 if (resultType.getRank() == 0) {
1782 splatOp, vectorType, poison, adaptor.getInput(), zero);
1787 auto v = rewriter.
create<LLVM::InsertElementOp>(
1788 splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
1790 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1807 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1809 VectorType resultType = splatOp.getType();
1810 if (resultType.getRank() <= 1)
1814 auto loc = splatOp.getLoc();
1815 auto vectorTypeInfo =
1817 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1818 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1819 if (!llvmNDVectorTy || !llvm1DVectorTy)
1823 Value desc = rewriter.
create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
1827 Value vdesc = rewriter.
create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
1828 auto zero = rewriter.
create<LLVM::ConstantOp>(
1831 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1832 adaptor.getInput(), zero);
1835 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1837 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1842 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1851 struct VectorInterleaveOpLowering
1856 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1858 VectorType resultType = interleaveOp.getResultVectorType();
1860 if (resultType.getRank() != 1)
1862 "InterleaveOp not rank 1");
1864 if (resultType.isScalable()) {
1866 interleaveOp, typeConverter->convertType(resultType),
1867 adaptor.getLhs(), adaptor.getRhs());
1874 int64_t resultVectorSize = resultType.getNumElements();
1876 interleaveShuffleMask.reserve(resultVectorSize);
1877 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1878 interleaveShuffleMask.push_back(i);
1879 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1882 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1883 interleaveShuffleMask);
1890 struct VectorDeinterleaveOpLowering
1895 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1897 VectorType resultType = deinterleaveOp.getResultVectorType();
1898 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1899 auto loc = deinterleaveOp.getLoc();
1903 if (resultType.getRank() != 1)
1905 "DeinterleaveOp not rank 1");
1907 if (resultType.isScalable()) {
1908 auto llvmTypeConverter = this->getTypeConverter();
1909 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1910 auto packedOpResults =
1911 llvmTypeConverter->packOperationResults(deinterleaveResults);
1912 auto intrinsic = rewriter.
create<LLVM::vector_deinterleave2>(
1913 loc, packedOpResults, adaptor.getSource());
1915 auto evenResult = rewriter.
create<LLVM::ExtractValueOp>(
1916 loc, intrinsic->getResult(0), 0);
1917 auto oddResult = rewriter.
create<LLVM::ExtractValueOp>(
1918 loc, intrinsic->getResult(0), 1);
1927 int64_t resultVectorSize = resultType.getNumElements();
1931 evenShuffleMask.reserve(resultVectorSize);
1932 oddShuffleMask.reserve(resultVectorSize);
1934 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1936 evenShuffleMask.push_back(i);
1938 oddShuffleMask.push_back(i);
1941 auto poison = rewriter.
create<LLVM::PoisonOp>(loc, sourceType);
1942 auto evenShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1943 loc, adaptor.getSource(), poison, evenShuffleMask);
1944 auto oddShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1945 loc, adaptor.getSource(), poison, oddShuffleMask);
1953 struct VectorFromElementsLowering
1958 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1960 Location loc = fromElementsOp.getLoc();
1961 VectorType vectorType = fromElementsOp.getType();
1964 if (vectorType.getRank() > 1)
1966 "rank > 1 vectors are not supported");
1967 Type llvmType = typeConverter->convertType(vectorType);
1968 Value result = rewriter.
create<LLVM::PoisonOp>(loc, llvmType);
1970 result = rewriter.
create<vector::InsertOp>(loc, val, result, idx);
1971 rewriter.
replaceOp(fromElementsOp, result);
1977 struct VectorScalableStepOpLowering
1982 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1984 auto resultType = cast<VectorType>(stepOp.getType());
1985 if (!resultType.isScalable()) {
1988 Type llvmType = typeConverter->convertType(stepOp.getType());
2004 bool reassociateFPReductions,
bool force32BitVectorIndices,
2005 bool useVectorAlignment) {
2008 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2009 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2010 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2011 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2012 VectorLoadStoreConversion<vector::StoreOp>,
2013 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2014 VectorGatherOpConversion, VectorScatterOpConversion>(
2015 converter, useVectorAlignment);
2016 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2017 VectorExtractElementOpConversion, VectorExtractOpConversion,
2018 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
2019 VectorInsertOpConversion, VectorPrintOpConversion,
2020 VectorTypeCastOpConversion, VectorScaleOpConversion,
2021 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2022 VectorSplatOpLowering, VectorSplatNdOpLowering,
2023 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2024 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2025 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2026 VectorScalableStepOpLowering>(converter);
2031 patterns.add<VectorMatmulOpConversion>(converter);
2032 patterns.add<VectorFlatTransposeOpConversion>(converter);
2038 void loadDependentDialects(
MLIRContext *context)
const final {
2039 context->loadDialect<LLVM::LLVMDialect>();
2044 void populateConvertToLLVMConversionPatterns(
2055 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void populateVectorToLLVMMatrixConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...