31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/Support/Casting.h"
43 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
46 auto constant = rewriter.
create<LLVM::ConstantOp>(
49 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
52 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
58 Value val,
Type llvmType, int64_t rank, int64_t pos) {
61 auto constant = rewriter.
create<LLVM::ConstantOp>(
64 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
67 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
72 VectorType vectorType,
unsigned &align) {
74 if (!convertedVectorTy)
77 llvm::LLVMContext llvmContext;
87 MemRefType memrefType,
unsigned &align) {
88 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
94 llvm::LLVMContext llvmContext;
107 VectorType vectorType,
108 MemRefType memrefType,
unsigned &align,
109 bool useVectorAlignment) {
110 if (useVectorAlignment) {
125 if (!memRefType.isLastDimUnitStride())
135 MemRefType memRefType,
Value llvmMemref,
Value base,
136 Value index, VectorType vectorType) {
138 "unsupported memref type");
139 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
143 vectorType.getScalableDims()[0]);
144 return rewriter.
create<LLVM::GEPOp>(
145 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
153 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
154 auto intAttr = cast<IntegerAttr>(attr);
155 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
158 return cast<Value>(foldResult);
164 using VectorScaleOpConversion =
// Rewrite pattern for vector::BitCastOp.
// Visible steps: read the op's result vector type, test whether its rank is
// greater than 1 (the statement executed in that case is on a line not shown
// here -- presumably the op is rejected; TODO confirm), convert the result
// type through the type converter, and pass the first adapted operand to the
// replacement op (whose name is also outside this excerpt).
168 class VectorBitCastOpConversion
174 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 VectorType resultTy = bitCastOp.getResultVectorType();
178 if (resultTy.getRank() > 1)
180 Type newResultTy = typeConverter->convertType(resultTy);
182 adaptor.getOperands()[0]);
// Rewrite pattern for vector::MatmulOp.
// The visible replacement call forwards, in order: the converted type of the
// op's result, the adapted lhs and rhs operands, and the op's lhs-rows,
// lhs-columns and rhs-columns values. The name of the op being created is on
// a line not shown here.
189 class VectorMatmulOpConversion
195 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
198 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
199 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
200 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
// Rewrite pattern for vector::FlatTransposeOp.
// The visible replacement call forwards the converted type of the op's
// result, the adapted matrix operand, and the op's rows and columns values.
// The name of the op being created is on a line not shown here.
207 class VectorFlatTransposeOpConversion
213 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
216 transOp, typeConverter->convertType(transOp.getRes().getType()),
217 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
// Overload set: one replaceLoadOrStoreOp per memory op kind (LoadOp,
// MaskedLoadOp, StoreOp, MaskedStoreOp). Each overload takes the op, its
// adaptor, a vector type, a pointer and an alignment; any further parameters
// and the name of the replacement op are on lines not shown here. Callers
// pick the overload through the static type of the first argument.

// vector::LoadOp: the visible tail of the replacement call forwards the op's
// nontemporal flag.
225 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
226 vector::LoadOpAdaptor adaptor,
227 VectorType vectorTy,
Value ptr,
unsigned align,
231 loadOp.getNontemporal());
// vector::MaskedLoadOp: forwards the vector type, pointer, adapted mask,
// adapted pass-through value and alignment.
234 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
235 vector::MaskedLoadOpAdaptor adaptor,
236 VectorType vectorTy,
Value ptr,
unsigned align,
239 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
// vector::StoreOp: the visible tail of the replacement call forwards the
// op's nontemporal flag.
242 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
243 vector::StoreOpAdaptor adaptor,
244 VectorType vectorTy,
Value ptr,
unsigned align,
248 storeOp.getNontemporal());
// vector::MaskedStoreOp: forwards the adapted value to store, pointer,
// adapted mask and alignment. vectorTy is not among the visible arguments.
251 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
252 vector::MaskedStoreOpAdaptor adaptor,
253 VectorType vectorTy,
Value ptr,
unsigned align,
256 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
261 template <
class LoadOrStoreOp>
267 useVectorAlignment(useVectorAlign) {}
271 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
272 typename LoadOrStoreOp::Adaptor adaptor,
275 VectorType vectorTy = loadOrStoreOp.getVectorType();
276 if (vectorTy.getRank() > 1)
279 auto loc = loadOrStoreOp->getLoc();
280 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
285 memRefTy, align, useVectorAlignment)))
287 "could not resolve alignment");
290 auto vtype = cast<VectorType>(
291 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
293 adaptor.getIndices(), rewriter);
294 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
304 const bool useVectorAlignment;
// Rewrite pattern for vector::GatherOp.
// Holds a useVectorAlignment flag set from the constructor argument and
// passed on to the alignment-resolution call. Visible flow of
// matchAndRewrite: obtain the base as a MemRefType, reject vectors of rank
// greater than 1, resolve the alignment, compute the strided element pointer
// for the adapted base/indices, build per-element pointers with
// getIndexedPtrs from the adapted index vector, and replace the op using the
// converted vector type, those pointers and the adapted mask.
308 class VectorGatherOpConversion
314 useVectorAlignment(useVectorAlign) {}
318 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
// dyn_cast yields a null type when the base is not a memref; the assert below
// is the only guard visible in this excerpt.
// NOTE(review): with assertions compiled out a non-memref base would flow
// into the code below as a null MemRefType -- confirm an earlier check
// exists on the lines not shown.
321 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
322 assert(memRefType &&
"The base should be bufferized");
327 VectorType vType = gather.getVectorType();
328 if (vType.getRank() > 1) {
330 gather,
"only 1-D vectors can be lowered to LLVM");
336 memRefType, align, useVectorAlignment)))
340 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
341 adaptor.getIndices(), rewriter);
342 Value base = adaptor.getBase();
344 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
345 base, ptr, adaptor.getIndexVec(), vType);
349 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
// Selects which alignment source the alignment-resolution call uses.
359 const bool useVectorAlignment;
// Rewrite pattern for vector::ScatterOp.
// Holds a useVectorAlignment flag set from the constructor argument and
// passed on to the alignment-resolution call. Visible flow of
// matchAndRewrite: read the memref and vector types from the op, reject
// vectors of rank greater than 1, resolve the alignment (the failure path
// carries the message "could not resolve alignment"), compute the strided
// element pointer for the adapted base/indices, build per-element pointers
// with getIndexedPtrs from the adapted index vector, and replace the op
// using the adapted value to store, those pointers and the adapted mask.
363 class VectorScatterOpConversion
369 useVectorAlignment(useVectorAlign) {}
374 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
376 auto loc = scatter->getLoc();
377 MemRefType memRefType = scatter.getMemRefType();
382 VectorType vType = scatter.getVectorType();
383 if (vType.getRank() > 1) {
385 scatter,
"only 1-D vectors can be lowered to LLVM");
391 memRefType, align, useVectorAlignment)))
393 "could not resolve alignment");
396 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397 adaptor.getIndices(), rewriter);
399 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
400 adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
404 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
// Selects which alignment source the alignment-resolution call uses.
414 const bool useVectorAlignment;
// Rewrite pattern for vector::ExpandLoadOp.
// Visible flow: convert the op's vector type, compute the strided element
// pointer for the adapted base/indices, then replace the op with one built
// from the converted type, that pointer, the adapted mask and the adapted
// pass-through value. The name of the replacement op is on a line not shown.
418 class VectorExpandLoadOpConversion
424 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
426 auto loc = expand->getLoc();
427 MemRefType memRefType = expand.getMemRefType();
430 auto vtype = typeConverter->convertType(expand.getVectorType());
431 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
432 adaptor.getIndices(), rewriter);
435 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
// Rewrite pattern for vector::CompressStoreOp.
// Visible flow: compute the strided element pointer for the adapted
// base/indices, then replace the op with one built from the adapted value to
// store, that pointer and the adapted mask. The name of the replacement op
// is on a line not shown.
441 class VectorCompressStoreOpConversion
447 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
449 auto loc = compress->getLoc();
450 MemRefType memRefType = compress.getMemRefType();
453 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
454 adaptor.getIndices(), rewriter);
457 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
// Empty tag types. Each one is the first-parameter type of exactly one
// createReductionNeutralValue overload, so passing a default-constructed tag
// (or naming it as the ReductionNeutral template argument of the reduction
// helpers) picks the constant that overload builds. The types carry no data.
463 class ReductionNeutralZero {};
464 class ReductionNeutralIntOne {};
465 class ReductionNeutralFPOne {};
466 class ReductionNeutralAllOnes {};
467 class ReductionNeutralSIntMin {};
468 class ReductionNeutralUIntMin {};
469 class ReductionNeutralSIntMax {};
470 class ReductionNeutralUIntMax {};
471 class ReductionNeutralFPMin {};
472 class ReductionNeutralFPMax {};
// Overload set: createReductionNeutralValue is overloaded on a tag type in
// its first parameter, and every overload returns the result of creating an
// LLVM::ConstantOp through the rewriter. For most overloads the constant's
// value argument is on a line not shown in this excerpt; the comments below
// state only what is visible.

// Tag: ReductionNeutralZero. The constant is created with type llvmType.
475 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
478 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
// Tag: ReductionNeutralIntOne.
483 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
486 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralFPOne.
491 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
494 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralAllOnes.
499 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
502 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralSIntMin.
509 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
512 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralUIntMin.
519 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
522 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralSIntMax.
529 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
532 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralUIntMax.
539 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
542 return rewriter.
create<LLVM::ConstantOp>(
// Tag: ReductionNeutralFPMin. llvmType is cast to FloatType (cast<> asserts
// on mismatch) and the constant is built from llvm::APFloat::getQNaN over
// that type's float semantics; the remaining getQNaN arguments are on a line
// not shown.
549 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
552 auto floatType = cast<FloatType>(llvmType);
553 return rewriter.
create<LLVM::ConstantOp>(
556 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
// Tag: ReductionNeutralFPMax. Same visible shape as the FPMin overload: a
// quiet-NaN constant over the float semantics of llvmType; the remaining
// getQNaN arguments (which may differ from FPMin) are on a line not shown.
561 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
564 auto floatType = cast<FloatType>(llvmType);
565 return rewriter.
create<LLVM::ConstantOp>(
568 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
574 template <
class ReductionNeutral>
581 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
590 VectorType vType = cast<VectorType>(llvmType);
591 auto vShape = vType.getShape();
592 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
594 Value baseVecLength = rewriter.
create<LLVM::ConstantOp>(
598 if (!vType.getScalableDims()[0])
599 return baseVecLength;
602 Value vScale = rewriter.
create<vector::VectorScaleOp>(loc);
605 Value scalableVecLength =
606 rewriter.
create<arith::MulIOp>(loc, baseVecLength, vScale);
607 return scalableVecLength;
// Lowers an integer arithmetic reduction.
//   LLVMRedIntrinOp - op created over the whole vector operand to produce
//                     the reduced value of type llvmType.
//   ScalarOp        - op created from (accumulator, reduced value) to fold
//                     the accumulator into the result.
// The statement guarding the ScalarOp step and the return are on lines not
// shown here -- presumably the fold is skipped when no accumulator is
// supplied; TODO confirm.
614 template <
class LLVMRedIntrinOp,
class ScalarOp>
615 static Value createIntegerReductionArithmeticOpLowering(
619 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
622 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
// Lowers an integer min/max-style reduction.
//   LLVMRedIntrinOp - op created over the whole vector operand to produce
//                     the reduced value of type llvmType.
//   predicate       - ICmp predicate used to compare accumulator against the
//                     reduced value.
// After the vector reduction, an LLVM::ICmpOp compares (accumulator, result)
// with `predicate`, and an LLVM::SelectOp keeps accumulator when the
// comparison holds and result otherwise. Any guard around the compare/select
// step, and the return, are on lines not shown here.
630 template <
class LLVMRedIntrinOp>
631 static Value createIntegerReductionComparisonOpLowering(
633 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
634 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
637 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
638 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
// Compile-time map from a floating-point vector-reduction intrinsic op to
// the scalar op of the same flavour, exposed as the nested alias `Type`.
// The primary template is only declared, so using an intrinsic without a
// specialization below is a compile error.
//   vector_reduce_fmaximum -> MaximumOp
//   vector_reduce_fminimum -> MinimumOp
//   vector_reduce_fmax     -> MaxNumOp
//   vector_reduce_fmin     -> MinNumOp
644 template <
typename Source>
645 struct VectorToScalarMapper;
647 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
648 using Type = LLVM::MaximumOp;
651 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
652 using Type = LLVM::MinimumOp;
655 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
656 using Type = LLVM::MaxNumOp;
659 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
660 using Type = LLVM::MinNumOp;
// Lowers a floating-point min/max-style reduction.
//   LLVMRedIntrinOp - intrinsic op created over the whole vector operand,
//                     with the fast-math flags `fmf` attached.
// The reduced value is then combined with the accumulator by creating the
// scalar op that VectorToScalarMapper<LLVMRedIntrinOp>::Type names, with
// operands (result, accumulator). Any guard around that combine step, and
// the return, are on lines not shown here.
664 template <
class LLVMRedIntrinOp>
665 static Value createFPReductionComparisonOpLowering(
667 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
669 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
673 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
674 loc, result, accumulator);
// Empty tag types plus the getMaskNeutralValue overloads they select. The
// returned value is what masked-off lanes are replaced with before a regular
// (unmasked) fmaximum / fminimum reduction runs.
//   MaskNeutralFMaximum -> llvm::APFloat::getSmallest(semantics, true)
//   MaskNeutralFMinimum -> llvm::APFloat::getLargest(semantics, false)
// NOTE(review): per the APFloat documentation getSmallest returns the finite
// value of smallest magnitude, and the `true` argument makes it negative, so
// the fmaximum filler is a tiny negative number rather than the most
// negative representable one. If the active lanes are all below that value
// the filler would win the maximum -- confirm this is intended, or whether
// getLargest(semantics, true) was meant.
681 class MaskNeutralFMaximum {};
682 class MaskNeutralFMinimum {};
686 getMaskNeutralValue(MaskNeutralFMaximum,
687 const llvm::fltSemantics &floatSemantics) {
688 return llvm::APFloat::getSmallest(floatSemantics,
true);
692 getMaskNeutralValue(MaskNeutralFMinimum,
693 const llvm::fltSemantics &floatSemantics) {
694 return llvm::APFloat::getLargest(floatSemantics,
false);
698 template <
typename MaskNeutral>
702 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
703 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
705 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
// Masked FP min/max reduction implemented with a regular (unmasked)
// reduction. The line carrying this function's name is not shown here; the
// call sites that pass <intrinsic, MaskNeutral*> plus a mask and fmf use the
// name lowerMaskedReductionWithRegular -- TODO confirm.
//   LLVMRedIntrinOp - intrinsic forwarded to
//                     createFPReductionComparisonOpLowering.
//   MaskNeutral     - tag selecting the filler value for masked-off lanes.
// Steps: build the filler constant with the operand's vector type, select
// per lane between the operand (mask set) and the filler (mask clear), then
// reduce the selected vector together with the accumulator.
712 template <
class LLVMRedIntrinOp,
class MaskNeutral>
717 Value mask, LLVM::FastmathFlagsAttr fmf) {
718 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
719 rewriter, loc, llvmType, vectorOperand.
getType());
720 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
721 loc, mask, vectorOperand, vectorMaskNeutral);
722 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
723 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
726 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
730 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
731 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
732 llvmType, accumulator);
733 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
741 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
746 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
747 llvmType, accumulator);
748 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
// Lowers a masked reduction to a vector-predicated reduction op.
//   LLVMVPRedIntrinOp - the VP reduction op to create.
//   ReductionNeutral  - tag passed to getOrCreateAccumulator, which yields
//                       the start value used for the reduction.
// Steps: obtain the accumulator via getOrCreateAccumulator, compute the
// vector length value from the operand's type with createVectorLengthValue,
// and create the VP op with result type llvmType; the visible trailing
// arguments are the vector operand, the mask and the vector length (the
// argument between llvmType and the operand is on a line not shown).
753 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
754 static Value lowerPredicatedReductionWithStartValue(
757 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
758 llvmType, accumulator);
760 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
761 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
763 vectorOperand, mask, vectorLength);
// Four-parameter form of lowerPredicatedReductionWithStartValue: carries one
// (VP op, neutral tag) pair for integers and one for floating point, and
// forwards to the two-parameter form with whichever pair applies. The
// conditions that choose between the two returns are on lines not shown
// here -- presumably they test the element type; TODO confirm.
766 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
767 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
768 static Value lowerPredicatedReductionWithStartValue(
772 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
773 IntReductionNeutral>(
774 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
777 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
779 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
// Rewrite pattern for vector::ReductionOp.
// Stores a reassociateFPReductions flag taken from the constructor argument.
// matchAndRewrite reads the combining kind, converts the destination element
// type, and takes the vector and accumulator from the adaptor.
//
// Integer-style kinds (visible cases):
//   ADD / MUL / AND / OR / XOR -> createIntegerReductionArithmeticOpLowering
//       with vector_reduce_add / mul / and / or / xor (the scalar op
//       template argument is on lines not shown).
//   MINSI / MAXUI / MAXSI      -> createIntegerReductionComparisonOpLowering
//       with vector_reduce_smin / umax / smax and ICmp sle / uge / sge.
//
// Floating-point kinds (reached after the isa<FloatType> check):
//   ADD      -> lowerReductionWithStartValue<vector_reduce_fadd, Zero>
//   MUL      -> lowerReductionWithStartValue<vector_reduce_fmul, FPOne>
//   MINIMUMF / MAXIMUMF / MINNUMF / MAXNUMF
//            -> createFPReductionComparisonOpLowering with
//               vector_reduce_fminimum / fmaximum / fmin / fmax.
783 class VectorReductionOpConversion
787 bool reassociateFPRed)
789 reassociateFPReductions(reassociateFPRed) {}
792 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
794 auto kind = reductionOp.getKind();
795 Type eltType = reductionOp.getDest().getType();
796 Type llvmType = typeConverter->convertType(eltType);
797 Value operand = adaptor.getVector();
798 Value acc = adaptor.getAcc();
799 Location loc = reductionOp.getLoc();
805 case vector::CombiningKind::ADD:
807 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
809 rewriter, loc, llvmType, operand, acc);
811 case vector::CombiningKind::MUL:
813 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
815 rewriter, loc, llvmType, operand, acc);
// The case label for this branch is on a line not shown; the umin intrinsic
// and the `ule` predicate suggest it is the unsigned-minimum kind -- TODO
// confirm.
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::ule);
822 case vector::CombiningKind::MINSI:
823 result = createIntegerReductionComparisonOpLowering<
824 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
825 LLVM::ICmpPredicate::sle);
827 case vector::CombiningKind::MAXUI:
828 result = createIntegerReductionComparisonOpLowering<
829 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
830 LLVM::ICmpPredicate::uge);
832 case vector::CombiningKind::MAXSI:
833 result = createIntegerReductionComparisonOpLowering<
834 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
835 LLVM::ICmpPredicate::sge);
837 case vector::CombiningKind::AND:
839 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
841 rewriter, loc, llvmType, operand, acc);
843 case vector::CombiningKind::OR:
845 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
847 rewriter, loc, llvmType, operand, acc);
849 case vector::CombiningKind::XOR:
851 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
853 rewriter, loc, llvmType, operand, acc);
863 if (!isa<FloatType>(eltType))
866 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
868 reductionOp.getContext(),
869 reductionOp.getContext(),
// The existing fast-math flags are OR-ed with `reassoc` when
// reassociateFPReductions is set, and left unchanged otherwise.
872 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
873 : LLVM::FastmathFlags::none));
877 if (
kind == vector::CombiningKind::ADD) {
878 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
879 ReductionNeutralZero>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 }
else if (
kind == vector::CombiningKind::MUL) {
882 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
883 ReductionNeutralFPOne>(
884 rewriter, loc, llvmType, operand, acc, fmf);
885 }
else if (
kind == vector::CombiningKind::MINIMUMF) {
887 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
888 rewriter, loc, llvmType, operand, acc, fmf);
889 }
else if (
kind == vector::CombiningKind::MAXIMUMF) {
891 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
892 rewriter, loc, llvmType, operand, acc, fmf);
893 }
else if (
kind == vector::CombiningKind::MINNUMF) {
894 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
895 rewriter, loc, llvmType, operand, acc, fmf);
896 }
else if (
kind == vector::CombiningKind::MAXNUMF) {
897 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
898 rewriter, loc, llvmType, operand, acc, fmf);
// When true, `reassoc` is added to the fast-math flags of FP reductions.
907 const bool reassociateFPReductions;
// Base class for patterns that rewrite a vector::MaskOp wrapping one
// specific kind of op.
//   MaskedOp - the op type the mask is expected to wrap.
// matchAndRewrite casts the mask's maskable op to MaskedOp with
// dyn_cast_or_null (so both "no maskable op" and "different op kind" yield
// null) and then delegates to the virtual matchAndRewriteMaskableOp hook,
// which derived classes implement. Whatever handles the null case sits
// between the cast and the delegation on lines not shown here.
918 template <
class MaskedOp>
919 class VectorMaskOpConversionBase
925 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
928 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
931 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
935 virtual LogicalResult
936 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
937 vector::MaskableOpInterface maskableOp,
// Rewrites a vector::MaskOp that wraps a vector::ReductionOp; inherits its
// constructors from VectorMaskOpConversionBase<vector::ReductionOp>.
// The vector operand and accumulator are read directly from the ReductionOp
// (not from an adaptor) and the mask comes from maskOp.getMask().
//
// Visible kind -> (VP reduction op, start-value tag) pairs:
//   ADD     -> VPReduceAddOp/Zero (int), VPReduceFAddOp/Zero (FP)
//   MUL     -> VPReduceMulOp/IntOne (int), VPReduceFMulOp/FPOne (FP)
//   MINSI   -> VPReduceSMinOp / SIntMax
//   MAXUI   -> VPReduceUMaxOp / UIntMin
//   MAXSI   -> VPReduceSMaxOp / SIntMin
//   AND     -> VPReduceAndOp  / AllOnes
//   OR      -> VPReduceOrOp   / Zero
//   XOR     -> VPReduceXorOp  / Zero
//   MINNUMF -> VPReduceFMinOp / FPMax
//   MAXNUMF -> VPReduceFMaxOp / FPMin
// MAXIMUMF and MINIMUMF instead go through lowerMaskedReductionWithRegular
// with vector_reduce_fmaximum / vector_reduce_fminimum and the matching
// MaskNeutral tag. The MaskOp is finally replaced by the computed result.
941 class MaskedReductionOpConversion
942 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
945 using VectorMaskOpConversionBase<
946 vector::ReductionOp>::VectorMaskOpConversionBase;
948 LogicalResult matchAndRewriteMaskableOp(
949 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
951 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
952 auto kind = reductionOp.getKind();
953 Type eltType = reductionOp.getDest().getType();
954 Type llvmType = typeConverter->convertType(eltType);
955 Value operand = reductionOp.getVector();
956 Value acc = reductionOp.getAcc();
957 Location loc = reductionOp.getLoc();
959 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
961 reductionOp.getContext(),
966 case vector::CombiningKind::ADD:
967 result = lowerPredicatedReductionWithStartValue<
968 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
969 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
972 case vector::CombiningKind::MUL:
973 result = lowerPredicatedReductionWithStartValue<
974 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
975 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
// The case label for this branch is on a line not shown; VPReduceUMinOp with
// the UIntMax start value suggests the unsigned-minimum kind -- TODO confirm.
979 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
980 ReductionNeutralUIntMax>(
981 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
983 case vector::CombiningKind::MINSI:
984 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
985 ReductionNeutralSIntMax>(
986 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
988 case vector::CombiningKind::MAXUI:
989 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
990 ReductionNeutralUIntMin>(
991 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
993 case vector::CombiningKind::MAXSI:
994 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
995 ReductionNeutralSIntMin>(
996 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
998 case vector::CombiningKind::AND:
999 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
1000 ReductionNeutralAllOnes>(
1001 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1003 case vector::CombiningKind::OR:
1004 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
1005 ReductionNeutralZero>(
1006 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1008 case vector::CombiningKind::XOR:
1009 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
1010 ReductionNeutralZero>(
1011 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1013 case vector::CombiningKind::MINNUMF:
1014 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1015 ReductionNeutralFPMax>(
1016 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1018 case vector::CombiningKind::MAXNUMF:
1019 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1020 ReductionNeutralFPMin>(
1021 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1023 case CombiningKind::MAXIMUMF:
1024 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1025 MaskNeutralFMaximum>(
1026 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1028 case CombiningKind::MINIMUMF:
1029 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1030 MaskNeutralFMinimum>(
1031 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1036 rewriter.replaceOp(maskOp, result);
1041 class VectorShuffleOpConversion
1047 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1049 auto loc = shuffleOp->getLoc();
1050 auto v1Type = shuffleOp.getV1VectorType();
1051 auto v2Type = shuffleOp.getV2VectorType();
1052 auto vectorType = shuffleOp.getResultVectorType();
1053 Type llvmType = typeConverter->convertType(vectorType);
1061 int64_t rank = vectorType.getRank();
1063 bool wellFormed0DCase =
1064 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1065 bool wellFormedNDCase =
1066 v1Type.getRank() == rank && v2Type.getRank() == rank;
1067 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1072 if (rank <= 1 && v1Type == v2Type) {
1073 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1074 loc, adaptor.getV1(), adaptor.getV2(),
1075 llvm::to_vector_of<int32_t>(mask));
1076 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1081 int64_t v1Dim = v1Type.getDimSize(0);
1083 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1084 eltType = arrayType.getElementType();
1086 eltType = cast<VectorType>(llvmType).getElementType();
1087 Value insert = rewriter.
create<LLVM::PoisonOp>(loc, llvmType);
1089 for (int64_t extPos : mask) {
1090 Value value = adaptor.getV1();
1091 if (extPos >= v1Dim) {
1093 value = adaptor.getV2();
1096 eltType, rank, extPos);
1097 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1098 llvmType, rank, insPos++);
// Rewrite pattern for vector::ExtractElementOp.
// The result type is the converted element type of the source vector.
// Rank-0 sources take a separate path: a constant named `zero` is created
// (its value attribute is on a line not shown) and used as the position in
// the replacement. Otherwise the adapted position operand is used. Both
// replacements receive the adapted source vector.
1105 class VectorExtractElementOpConversion
1112 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1114 auto vectorType = extractEltOp.getSourceVectorType();
1115 auto llvmType = typeConverter->convertType(vectorType.getElementType());
1121 if (vectorType.getRank() == 0) {
1122 Location loc = extractEltOp.getLoc();
1124 auto zero = rewriter.
create<LLVM::ConstantOp>(
1125 loc, typeConverter->convertType(idxType),
1128 extractEltOp, llvmType, adaptor.getVector(), zero);
1133 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1138 class VectorExtractOpConversion
1144 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1146 auto loc = extractOp->getLoc();
1147 auto resultType = extractOp.getResult().getType();
1148 auto llvmResultType = typeConverter->convertType(resultType);
1150 if (!llvmResultType)
1154 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1168 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1172 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1173 extractOp.getSourceVectorType().getRank();
1177 if (extractOp.getSourceVectorType().getRank() == 0) {
1179 positionVec.push_back(rewriter.
getZeroAttr(idxType));
1182 Value extracted = adaptor.getVector();
1183 if (extractsAggregate) {
1185 if (extractsScalar) {
1189 position = position.drop_back();
1192 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1195 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1199 if (extractsScalar) {
1200 extracted = rewriter.
create<LLVM::ExtractElementOp>(
1201 loc, extracted,
getAsLLVMValue(rewriter, loc, positionVec.back()));
1204 rewriter.
replaceOp(extractOp, extracted);
1228 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1230 VectorType vType = fmaOp.getVectorType();
1231 if (vType.getRank() > 1)
1235 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
// Rewrite pattern for vector::InsertElementOp.
// The result type is the converted destination vector type.
// Rank-0 destinations take a separate path: a constant named `zero` is
// created (its value attribute is on a line not shown) and used as the
// position in the replacement. Otherwise the adapted position operand is
// used. Both replacements receive the adapted destination and source.
1240 class VectorInsertElementOpConversion
1246 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1248 auto vectorType = insertEltOp.getDestVectorType();
1249 auto llvmType = typeConverter->convertType(vectorType);
1255 if (vectorType.getRank() == 0) {
1256 Location loc = insertEltOp.getLoc();
1258 auto zero = rewriter.
create<LLVM::ConstantOp>(
1259 loc, typeConverter->convertType(idxType),
1262 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1267 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1268 adaptor.getPosition());
1273 class VectorInsertOpConversion
1279 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1281 auto loc = insertOp->getLoc();
1282 auto destVectorType = insertOp.getDestVectorType();
1283 auto llvmResultType = typeConverter->convertType(destVectorType);
1285 if (!llvmResultType)
1289 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1311 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1313 bool insertIntoInnermostDim =
1314 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
1317 positionVec.begin(),
1318 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1320 if (destVectorType.getRank() == 0) {
1324 positionOfScalarWithin1DVector = rewriter.
getZeroAttr(idxType);
1325 }
else if (insertIntoInnermostDim) {
1326 positionOfScalarWithin1DVector = positionVec.back();
1332 Value sourceAggregate = adaptor.getValueToStore();
1333 if (insertIntoInnermostDim) {
1336 if (isNestedAggregate) {
1339 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1340 llvm::IsaPred<Attribute>)) {
1344 sourceAggregate = rewriter.
create<LLVM::ExtractValueOp>(
1345 loc, adaptor.getDest(),
1350 sourceAggregate = adaptor.getDest();
1353 sourceAggregate = rewriter.
create<LLVM::InsertElementOp>(
1354 loc, sourceAggregate.
getType(), sourceAggregate,
1355 adaptor.getValueToStore(),
1359 Value result = sourceAggregate;
1360 if (isNestedAggregate) {
1361 result = rewriter.
create<LLVM::InsertValueOp>(
1362 loc, adaptor.getDest(), sourceAggregate,
// Rewrite pattern for vector::ScalableInsertOp.
// The visible replacement call forwards the adapted destination, the adapted
// value to store and the adapted position. The name of the replacement op is
// on a line not shown here.
1372 struct VectorScalableInsertOpLowering
1378 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1381 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
// Rewrite pattern for vector::ScalableExtractOp.
// The visible replacement call forwards the converted result vector type,
// the adapted source and the adapted position. The name of the replacement
// op is on a line not shown here.
1387 struct VectorScalableExtractOpLowering
1393 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1396 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1397 adaptor.getSource(), adaptor.getPos());
1430 setHasBoundedRewriteRecursion();
1433 LogicalResult matchAndRewrite(FMAOp op,
1435 auto vType = op.getVectorType();
1436 if (vType.getRank() < 2)
1439 auto loc = op.getLoc();
1440 auto elemType = vType.getElementType();
1443 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1444 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1445 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1446 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1447 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1448 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1449 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
// Returns the strides of `memRefType` when the visible checks accept the
// layout, and std::nullopt otherwise.
// std::nullopt is returned when:
//   - getStridesAndOffset fails,
//   - the stride list is non-empty and its last (innermost) stride is not 1,
//   - inside the loop, a compared size or stride is dynamic, or an outer
//     stride differs from the next inner stride times the next inner size.
// An identity layout takes a shortcut whose statement is on a line not shown
// here (presumably returning the strides directly; TODO confirm), as are the
// declarations of `strides`/`offset` and the final return.
1458 static std::optional<SmallVector<int64_t, 4>>
1459 computeContiguousStrides(MemRefType memRefType) {
1462 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1463 return std::nullopt;
1464 if (!strides.empty() && strides.back() != 1)
1465 return std::nullopt;
1467 if (memRefType.getLayout().isIdentity())
1474 auto sizes = memRefType.getShape();
// Walks adjacent (outer, inner) dimension pairs, hence the size() - 1 bound.
// NOTE(review): for an empty stride list `strides.size() - 1` wraps in the
// unsigned domain before being stored in an int; on common targets that
// yields -1 and the loop is skipped -- confirm rank-0 memrefs are handled
// before this point or that this conversion is acceptable.
1475 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1476 if (ShapedType::isDynamic(sizes[index + 1]) ||
1477 ShapedType::isDynamic(strides[index]) ||
1478 ShapedType::isDynamic(strides[index + 1]))
1479 return std::nullopt;
1480 if (strides[index] != strides[index + 1] * sizes[index + 1])
1481 return std::nullopt;
// Rewrite pattern for vector::TypeCastOp (memref -> memref).
// Visible checks, in order (the statement taken when a check fires is on
// lines not shown -- presumably the pattern fails to match; TODO confirm):
//   - source and target memref types must both have static shapes,
//   - the adapted operand's type must be an LLVM struct type,
//   - the converted target memref type must be an LLVM struct type,
//   - contiguous strides are computed for source and target with
//     computeContiguousStrides,
//   - the target strides are tested for dynamic entries.
// The target descriptor `desc` is then filled in: allocated and aligned
// pointers are taken from the source descriptor, the offset is set from a
// constant named `zero`, and each dimension's size and stride are set from
// i64 constants (the stride constant is built from (*targetStrides)[index]).
1486 class VectorTypeCastOpConversion
1492 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1494 auto loc = castOp->getLoc();
1495 MemRefType sourceMemRefType =
1496 cast<MemRefType>(castOp.getOperand().getType());
1497 MemRefType targetMemRefType = castOp.getType();
1500 if (!sourceMemRefType.hasStaticShape() ||
1501 !targetMemRefType.hasStaticShape())
1504 auto llvmSourceDescriptorTy =
1505 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1506 if (!llvmSourceDescriptorTy)
// dyn_cast_or_null: convertType may return a null type, which must not be
// passed to a plain dyn_cast.
1510 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1511 typeConverter->convertType(targetMemRefType));
1512 if (!llvmTargetDescriptorTy)
1516 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1519 auto targetStrides = computeContiguousStrides(targetMemRefType);
// targetStrides is dereferenced here; the check that it holds a value is
// expected on the lines not shown between these statements -- TODO confirm.
1523 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1531 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1532 desc.setAllocatedPtr(rewriter, loc, allocated);
1535 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1536 desc.setAlignedPtr(rewriter, loc, ptr);
1539 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1540 desc.setOffset(rewriter, loc, zero);
1543 for (
const auto &indexedSize :
1545 int64_t index = indexedSize.index();
1548 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1549 desc.setSize(rewriter, loc, index, size);
1551 (*targetStrides)[index]);
1552 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1553 desc.setStride(rewriter, loc, index, stride);
// Rewrite pattern for vector::CreateMaskOp.
// The constructor stores its enableIndexOpt argument as
// force32BitVectorIndices. matchAndRewrite only proceeds past its first
// check for rank-1 scalable result vectors (the statement executed otherwise
// is on a line not shown -- presumably a match failure; TODO confirm).
// The visible lowering creates an LLVM::StepVectorOp for the lane indices,
// uses the first adapted operand as the bound, and compares with
// arith::CmpIOp using the signed less-than predicate.
1563 class VectorCreateMaskOpConversion
1566 explicit VectorCreateMaskOpConversion(
MLIRContext *context,
1567 bool enableIndexOpt)
1569 force32BitVectorIndices(enableIndexOpt) {}
1572 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1574 auto dstType = op.getType();
1575 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1577 IntegerType idxType =
1579 auto loc = op->getLoc();
1580 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1584 adaptor.getOperands()[0]);
1586 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
// Set from the constructor's enableIndexOpt argument; how it selects the
// index type is on lines not shown here.
1593 const bool force32BitVectorIndices;
1614 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1616 auto parent =
printOp->getParentOfType<ModuleOp>();
1622 if (
auto value = adaptor.getSource()) {
1628 if (failed(emitScalarPrint(rewriter, parent, loc,
printType, value)))
1632 auto punct =
printOp.getPunctuation();
1633 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1636 *stringLiteral, *getTypeConverter(),
1638 if (createResult.failed())
1641 }
else if (punct != PrintPunctuation::NoPunctuation) {
1642 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1644 case PrintPunctuation::Close:
1646 case PrintPunctuation::Open:
1648 case PrintPunctuation::Comma:
1650 case PrintPunctuation::NewLine:
1653 llvm_unreachable(
"unexpected punctuation");
1658 emitCall(rewriter,
printOp->getLoc(), op.value());
1666 enum class PrintConversion {
1677 Value value)
const {
1678 if (typeConverter->convertType(
printType) ==
nullptr)
1683 FailureOr<Operation *> printer;
1689 conversion = PrintConversion::Bitcast16;
1692 conversion = PrintConversion::Bitcast16;
1696 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1700 unsigned width = intTy.getWidth();
1701 if (intTy.isUnsigned()) {
1704 conversion = PrintConversion::ZeroExt64;
1710 assert(intTy.isSignless() || intTy.isSigned());
1715 conversion = PrintConversion::ZeroExt64;
1716 else if (width < 64)
1717 conversion = PrintConversion::SignExt64;
1726 if (failed(printer))
1729 switch (conversion) {
1730 case PrintConversion::ZeroExt64:
1731 value = rewriter.
create<arith::ExtUIOp>(
1734 case PrintConversion::SignExt64:
1735 value = rewriter.
create<arith::ExtSIOp>(
1738 case PrintConversion::Bitcast16:
1739 value = rewriter.
create<LLVM::BitcastOp>(
1745 emitCall(rewriter, loc, printer.value(), value);
1763 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1765 VectorType resultType = cast<VectorType>(splatOp.getType());
1766 if (resultType.getRank() > 1)
1770 auto vectorType = typeConverter->convertType(splatOp.getType());
1772 rewriter.
create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
1773 auto zero = rewriter.
create<LLVM::ConstantOp>(
1779 if (resultType.getRank() == 0) {
1781 splatOp, vectorType, poison, adaptor.getInput(), zero);
1786 auto v = rewriter.
create<LLVM::InsertElementOp>(
1787 splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
1789 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1806 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1808 VectorType resultType = splatOp.getType();
1809 if (resultType.getRank() <= 1)
1813 auto loc = splatOp.getLoc();
1814 auto vectorTypeInfo =
1816 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1817 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1818 if (!llvmNDVectorTy || !llvm1DVectorTy)
1822 Value desc = rewriter.
create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
1826 Value vdesc = rewriter.
create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
1827 auto zero = rewriter.
create<LLVM::ConstantOp>(
1830 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1831 adaptor.getInput(), zero);
1834 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1836 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1841 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1850 struct VectorInterleaveOpLowering
1855 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1857 VectorType resultType = interleaveOp.getResultVectorType();
1859 if (resultType.getRank() != 1)
1861 "InterleaveOp not rank 1");
1863 if (resultType.isScalable()) {
1865 interleaveOp, typeConverter->convertType(resultType),
1866 adaptor.getLhs(), adaptor.getRhs());
1873 int64_t resultVectorSize = resultType.getNumElements();
1875 interleaveShuffleMask.reserve(resultVectorSize);
1876 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1877 interleaveShuffleMask.push_back(i);
1878 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1881 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1882 interleaveShuffleMask);
1889 struct VectorDeinterleaveOpLowering
1894 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1896 VectorType resultType = deinterleaveOp.getResultVectorType();
1897 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1898 auto loc = deinterleaveOp.getLoc();
1902 if (resultType.getRank() != 1)
1904 "DeinterleaveOp not rank 1");
1906 if (resultType.isScalable()) {
1907 auto llvmTypeConverter = this->getTypeConverter();
1908 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1909 auto packedOpResults =
1910 llvmTypeConverter->packOperationResults(deinterleaveResults);
1911 auto intrinsic = rewriter.
create<LLVM::vector_deinterleave2>(
1912 loc, packedOpResults, adaptor.getSource());
1914 auto evenResult = rewriter.
create<LLVM::ExtractValueOp>(
1915 loc, intrinsic->getResult(0), 0);
1916 auto oddResult = rewriter.
create<LLVM::ExtractValueOp>(
1917 loc, intrinsic->getResult(0), 1);
1926 int64_t resultVectorSize = resultType.getNumElements();
1930 evenShuffleMask.reserve(resultVectorSize);
1931 oddShuffleMask.reserve(resultVectorSize);
1933 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1935 evenShuffleMask.push_back(i);
1937 oddShuffleMask.push_back(i);
1940 auto poison = rewriter.
create<LLVM::PoisonOp>(loc, sourceType);
1941 auto evenShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1942 loc, adaptor.getSource(), poison, evenShuffleMask);
1943 auto oddShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1944 loc, adaptor.getSource(), poison, oddShuffleMask);
1952 struct VectorFromElementsLowering
1957 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1959 Location loc = fromElementsOp.getLoc();
1960 VectorType vectorType = fromElementsOp.getType();
1963 if (vectorType.getRank() > 1)
1965 "rank > 1 vectors are not supported");
1966 Type llvmType = typeConverter->convertType(vectorType);
1967 Value result = rewriter.
create<LLVM::PoisonOp>(loc, llvmType);
1969 result = rewriter.
create<vector::InsertOp>(loc, val, result, idx);
1970 rewriter.
replaceOp(fromElementsOp, result);
1976 struct VectorScalableStepOpLowering
1981 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1983 auto resultType = cast<VectorType>(stepOp.getType());
1984 if (!resultType.isScalable()) {
1987 Type llvmType = typeConverter->convertType(stepOp.getType());
2003 bool reassociateFPReductions,
bool force32BitVectorIndices,
2004 bool useVectorAlignment) {
2007 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2008 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2009 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2010 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2011 VectorLoadStoreConversion<vector::StoreOp>,
2012 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2013 VectorGatherOpConversion, VectorScatterOpConversion>(
2014 converter, useVectorAlignment);
2015 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2016 VectorExtractElementOpConversion, VectorExtractOpConversion,
2017 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
2018 VectorInsertOpConversion, VectorPrintOpConversion,
2019 VectorTypeCastOpConversion, VectorScaleOpConversion,
2020 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2021 VectorSplatOpLowering, VectorSplatNdOpLowering,
2022 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2023 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2024 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2025 VectorScalableStepOpLowering>(converter);
2030 patterns.add<VectorMatmulOpConversion>(converter);
2031 patterns.add<VectorFlatTransposeOpConversion>(converter);
2037 void loadDependentDialects(
MLIRContext *context)
const final {
2038 context->loadDialect<LLVM::LLVMDialect>();
2043 void populateConvertToLLVMConversionPatterns(
2054 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
union mlir::linalg::@1197::ArityGroupAndKind::Kind kind
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type`, or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void populateVectorToLLVMMatrixConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...