#include "llvm/ADT/APFloat.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Casting.h"
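/// Helpers that pick the proper LLVM insertion/extraction sequence:
/// insertelement/extractelement for rank-1 vectors, insertvalue/extractvalue
/// for the aggregate (array-of-vector) case.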
  assert(rank > 0 &&
         "0-D vector corner case should have been handled already");
    auto constant = LLVM::ConstantOp::create(
    return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
  return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
                       Value val, Type llvmType, int64_t rank, int64_t pos) {
    auto constant = LLVM::ConstantOp::create(
    return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
  return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
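/// Alignment helpers: query the preferred alignment of either the converted
/// vector type or the memref element type from the LLVM data layout;
/// getVectorToLLVMAlignment dispatches between the two based on
/// useVectorAlignment.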
                                        VectorType vectorType, unsigned &align) {
  if (!convertedVectorTy)
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(convertedVectorTy,
                                        MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
                                              VectorType vectorType,
                                              MemRefType memrefType,
                                              unsigned &align,
                                              bool useVectorAlignment) {
  if (useVectorAlignment) {
  if (!memRefType.isLastDimUnitStride())
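/// Builds the vector of per-lane pointers used by the gather/scatter
/// lowerings: a GEP of the strided base pointer indexed by the (vector)
/// index operand.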
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, VectorType vectorType) {
         "unsupported memref type");
  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
                         vectorType.getScalableDims()[0]);
  return LLVM::GEPOp::create(
      rewriter, loc, ptrsType,
      typeConverter.convertType(memRefType.getElementType()), base, index);
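/// Converts an OpFoldResult into a Value, materializing integer attributes as
/// LLVM constants.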
  if (auto attr = dyn_cast<Attribute>(foldResult)) {
    auto intAttr = cast<IntegerAttr>(attr);
    return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
  return cast<Value>(foldResult);
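/// vector.vscale maps 1:1 onto the LLVM vscale intrinsic; vector.bitcast is
/// converted only for vectors of rank <= 1.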
using VectorScaleOpConversion =
class VectorBitCastOpConversion
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
    Type newResultTy = typeConverter->convertType(resultTy);
                                                 adaptor.getOperands()[0]);
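/// Overloads that replace vector.load/store and their masked variants with
/// the corresponding LLVM memory op, forwarding the resolved alignment (and
/// the nontemporal flag for plain loads/stores).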
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                       loadOp.getNontemporal());
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                        storeOp.getNontemporal());
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
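/// Conversion pattern shared by vector.load/store and their masked variants.
/// Only 1-D vectors are handled; the alignment comes from the op's alignment
/// attribute when present and is otherwise resolved from the vector or
/// memref type.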
template <class LoadOrStoreOp>
        useVectorAlignment(useVectorAlign) {}
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();
    unsigned align = loadOrStoreOp.getAlignment().value_or(0);
                                         memRefTy, align, useVectorAlignment)))
                                         "could not resolve alignment");
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
        rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
  const bool useVectorAlignment;
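/// Conversion pattern for vector.gather: 1-D gathers become a masked-gather
/// intrinsic call on the vector of per-lane pointers built by getIndexedPtrs.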
class VectorGatherOpConversion
        useVectorAlignment(useVectorAlign) {}
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");
    VectorType vType = gather.getVectorType();
    if (vType.getRank() > 1) {
          gather, "only 1-D vectors can be lowered to LLVM");
                                         memRefType, align, useVectorAlignment)))
                             adaptor.getBase(), adaptor.getOffsets());
    Value base = adaptor.getBase();
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       base, ptr, adaptor.getIndices(), vType);
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
  const bool useVectorAlignment;
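/// Conversion pattern for vector.scatter: mirrors the gather lowering,
/// storing through the per-lane pointer vector.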
class VectorScatterOpConversion
        useVectorAlignment(useVectorAlign) {}
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();
    VectorType vType = scatter.getVectorType();
    if (vType.getRank() > 1) {
          scatter, "only 1-D vectors can be lowered to LLVM");
                                         memRefType, align, useVectorAlignment)))
                                         "could not resolve alignment");
                             adaptor.getBase(), adaptor.getOffsets());
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndices(), vType);
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
  const bool useVectorAlignment;
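/// Conversion pattern for vector.expandload, lowered to the LLVM expand-load
/// intrinsic on the strided element pointer.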
class VectorExpandLoadOpConversion
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();
    auto vtype = typeConverter->convertType(expand.getVectorType());
                             adaptor.getBase(), adaptor.getIndices());
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
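/// Conversion pattern for vector.compressstore, the symmetric lowering to the
/// LLVM compress-store intrinsic.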
class VectorCompressStoreOpConversion
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();
                             adaptor.getBase(), adaptor.getIndices());
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
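/// Empty tag types selecting the neutral element used to seed a reduction
/// when no accumulator is provided.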
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
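/// createReductionNeutralValue overloads materialize the neutral constant for
/// each tag above (zero, one, an all-ones mask, signed/unsigned integer
/// min/max, or a quiet NaN for the FP min/max reductions).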
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
  auto floatType = cast<FloatType>(llvmType);
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
  auto floatType = cast<FloatType>(llvmType);
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
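/// getOrCreateAccumulator returns the explicit accumulator when present and
/// the tag's neutral value otherwise; createVectorLengthValue computes the
/// runtime vector length (base length times vscale for scalable vectors)
/// required by the VP intrinsics.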
template <class ReductionNeutral>
  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
  Value baseVecLength = LLVM::ConstantOp::create(
  if (!vType.getScalableDims()[0])
    return baseVecLength;
  Value vScale = vector::VectorScaleOp::create(rewriter, loc);
      arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
  Value scalableVecLength =
      arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
  return scalableVecLength;
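/// Integer reductions with an arithmetic combiner: reduce with the LLVM
/// reduction intrinsic, then fold the accumulator (if any) in with ScalarOp.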
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
    result = ScalarOp::create(rewriter, loc, accumulator, result);
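/// Integer min/max reductions: the reduced value is combined with the
/// accumulator through an icmp + select pair.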
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
      LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
  result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
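/// Maps each LLVM FP min/max reduction intrinsic to the scalar op used to
/// combine its result with the accumulator.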
template <typename Source>
struct VectorToScalarMapper;
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
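/// FP min/max reductions: reduce with the intrinsic, then combine with the
/// accumulator using the scalar op selected by VectorToScalarMapper.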
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
    result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
        rewriter, loc, result, accumulator);
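/// Mask-neutral values for masked fmaximum/fminimum reductions: masked-off
/// lanes are replaced with a value that can never win the comparison.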
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
getMaskNeutralValue(MaskNeutralFMaximum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getSmallest(floatSemantics, true);
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, false);
template <typename MaskNeutral>
  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
  return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
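/// Lowers masked fmaximum/fminimum reductions, which have no VP intrinsic:
/// masked-off lanes are replaced by the mask-neutral value with a select, and
/// the regular FP reduction lowering is reused on the result.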
template <class LLVMRedIntrinOp, class MaskNeutral>
    Value mask, LLVM::FastmathFlagsAttr fmf) {
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  const Value selectedVectorByMask = LLVM::SelectOp::create(
      rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
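/// Reduction helpers that seed the reduction with either the provided
/// accumulator or the tag's neutral value; the predicated variants target the
/// LLVM VP reduction intrinsics and pass the mask and runtime vector length.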
template <class LLVMRedIntrinOp, class ReductionNeutral>
    Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
                                 accumulator, vectorOperand,
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
                                   accumulator, vectorOperand);
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
                                   accumulator, vectorOperand,
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
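/// Conversion pattern for unmasked vector.reduction: dispatches on the
/// combining kind to the integer/FP helpers above. FP add/mul reductions get
/// the reassoc fastmath flag when reassociation is enabled.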
class VectorReductionOpConversion
                              bool reassociateFPRed)
        reassociateFPReductions(reassociateFPRed) {}
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();
    case vector::CombiningKind::ADD:
          createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
              rewriter, loc, llvmType, operand, acc);
    case vector::CombiningKind::MUL:
          createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
              rewriter, loc, llvmType, operand, acc);
      result = createIntegerReductionComparisonOpLowering<
          LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                    LLVM::ICmpPredicate::ule);
    case vector::CombiningKind::MINSI:
      result = createIntegerReductionComparisonOpLowering<
          LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                    LLVM::ICmpPredicate::sle);
    case vector::CombiningKind::MAXUI:
      result = createIntegerReductionComparisonOpLowering<
          LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                    LLVM::ICmpPredicate::uge);
    case vector::CombiningKind::MAXSI:
      result = createIntegerReductionComparisonOpLowering<
          LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                    LLVM::ICmpPredicate::sge);
    case vector::CombiningKind::AND:
          createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
              rewriter, loc, llvmType, operand, acc);
    case vector::CombiningKind::OR:
          createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
              rewriter, loc, llvmType, operand, acc);
    case vector::CombiningKind::XOR:
          createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
              rewriter, loc, llvmType, operand, acc);
    if (!isa<FloatType>(eltType))
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
        reductionOp.getContext(),
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
  const bool reassociateFPReductions;
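/// Base class for patterns that lower a vector.mask op together with the
/// single maskable op in its region; derived classes implement
/// matchAndRewriteMaskableOp for the concrete masked op.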
template <class MaskedOp>
class VectorMaskOpConversionBase
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
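/// Masked vector.reduction lowering: integer kinds and fadd/fmul/minnumf/
/// maxnumf map onto LLVM VP reduction intrinsics; maximumf/minimumf, which
/// have no VP intrinsic, fall back to a select-with-neutral plus the regular
/// reduction.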
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;
  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
        reductionOp.getContext(),
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
    case CombiningKind::MAXIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
    rewriter.replaceOp(maskOp, result);
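/// Conversion pattern for vector.shuffle. Rank-1 shuffles of identically
/// typed operands map directly onto llvm.shufflevector; otherwise the result
/// is assembled element by element with extract/insert pairs.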
class VectorShuffleOpConversion
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    int64_t rank = vectorType.getRank();
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
          rewriter, loc, adaptor.getV1(), adaptor.getV2(),
          llvm::to_vector_of<int32_t>(mask));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
    int64_t v1Dim = v1Type.getDimSize(0);
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
    for (int64_t extPos : mask) {
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        value = adaptor.getV2();
                               eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
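/// Conversion pattern for vector.extract. Outer (aggregate) positions become
/// llvm.extractvalue; extracting a scalar from the innermost 1-D vector
/// becomes llvm.extractelement, which also supports a dynamic position.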
class VectorExtractOpConversion
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    if (!llvmResultType)
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
    bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
                          extractOp.getSourceVectorType().getRank();
    if (extractOp.getSourceVectorType().getRank() == 0) {
      positionVec.push_back(rewriter.getZeroAttr(idxType));
    Value extracted = adaptor.getVector();
    if (extractsAggregate) {
      if (extractsScalar) {
        position = position.drop_back();
      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
      extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
    if (extractsScalar) {
      extracted = LLVM::ExtractElementOp::create(
          rewriter, loc, extracted,
    rewriter.replaceOp(extractOp, extracted);
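/// Conversion pattern for rank-1 vector.fma; n-D FMAs are rank-reduced by a
/// separate rewrite pattern further down.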
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() > 1)
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
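/// Conversion pattern for vector.insert. The position is split into the
/// aggregate part (llvm.extractvalue/insertvalue) and, when a scalar is
/// inserted, the innermost 1-D position (llvm.insertelement).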
class VectorInsertOpConversion
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
    auto loc = insertOp->getLoc();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    if (!llvmResultType)
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
    bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
    bool insertIntoInnermostDim =
        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
        positionVec.begin(),
        insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
    if (destVectorType.getRank() == 0) {
      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
    } else if (insertIntoInnermostDim) {
      positionOfScalarWithin1DVector = positionVec.back();
    Value sourceAggregate = adaptor.getValueToStore();
    if (insertIntoInnermostDim) {
      if (isNestedAggregate) {
        if (!llvm::all_of(positionOf1DVectorWithinAggregate,
                          llvm::IsaPred<Attribute>)) {
        sourceAggregate = LLVM::ExtractValueOp::create(
            rewriter, loc, adaptor.getDest(),
        sourceAggregate = adaptor.getDest();
      sourceAggregate = LLVM::InsertElementOp::create(
          rewriter, loc, sourceAggregate.getType(), sourceAggregate,
          adaptor.getValueToStore(),
    Value result = sourceAggregate;
    if (isNestedAggregate) {
      result = LLVM::InsertValueOp::create(
          rewriter, loc, adaptor.getDest(), sourceAggregate,
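/// vector.scalable.insert / vector.scalable.extract lower directly to the
/// LLVM vector insert/extract intrinsics.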
struct VectorScalableInsertOpLowering
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
        insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
struct VectorScalableExtractOpLowering
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
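/// Rank-reducing rewrite for n-D vector.fma (n > 1): unrolls along the
/// leading dimension into (n-1)-D FMAs that are extracted, fused, and
/// re-inserted.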
    setHasBoundedRewriteRecursion();
  LogicalResult matchAndRewrite(FMAOp op,
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
    Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
      Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
      Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
      Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
      desc = InsertOp::create(rewriter, loc, fma, desc, i);
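/// Returns the strides of memRefType when it describes a contiguous buffer
/// with a unit innermost stride, std::nullopt otherwise; used to validate
/// vector.type_cast operands.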
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  if (failed(memRefType.getStridesAndOffset(strides, offset)))
    return std::nullopt;
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  if (memRefType.getLayout().isIdentity())
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
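/// Conversion pattern for vector.type_cast between statically shaped,
/// contiguous memrefs: the memref descriptor is rebuilt with the target sizes
/// and strides while reusing the source's allocated and aligned pointers.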
class VectorTypeCastOpConversion
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    desc.setAlignedPtr(rewriter, loc, ptr);
    auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);
    for (const auto &indexedSize :
      int64_t index = indexedSize.index();
      auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
          (*targetStrides)[index]);
          LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
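/// Conversion pattern for scalable 1-D vector.create_mask: compares a step
/// vector against the broadcast bound.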
class VectorCreateMaskOpConversion
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt)
        force32BitVectorIndices(enableIndexOpt) {}
  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
    IntegerType idxType =
    auto loc = op->getLoc();
    Value indices = LLVM::StepVectorOp::create(
        adaptor.getOperands()[0]);
    Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
    Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
  const bool force32BitVectorIndices;
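/// Conversion pattern for vector.print. Scalars are routed to the runtime's
/// lookupOrCreatePrint* helpers (with zero/sign extension or a 16-bit bitcast
/// where needed); punctuation and string literals emit calls to the matching
/// print functions.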
  explicit VectorPrintOpConversion(
        symbolTables(symbolTables) {}
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
    auto parent = printOp->getParentOfType<ModuleOp>();
    if (auto value = adaptor.getSource()) {
    auto punct = printOp.getPunctuation();
    if (auto stringLiteral = printOp.getStringLiteral()) {
          *stringLiteral, *getTypeConverter(),
      if (createResult.failed())
    } else if (punct != PrintPunctuation::NoPunctuation) {
      FailureOr<LLVM::LLVMFuncOp> op = [&]() {
        case PrintPunctuation::Close:
        case PrintPunctuation::Open:
        case PrintPunctuation::Comma:
        case PrintPunctuation::NewLine:
          llvm_unreachable("unexpected punctuation");
      emitCall(rewriter, printOp->getLoc(), op.value());
  enum class PrintConversion {
                              Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
    FailureOr<Operation *> printer;
      conversion = PrintConversion::Bitcast16;
      conversion = PrintConversion::Bitcast16;
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
          conversion = PrintConversion::ZeroExt64;
        assert(intTy.isSignless() || intTy.isSigned());
          conversion = PrintConversion::ZeroExt64;
        else if (width < 64)
          conversion = PrintConversion::SignExt64;
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = arith::ExtUIOp::create(
    case PrintConversion::SignExt64:
      value = arith::ExtSIOp::create(
    case PrintConversion::Bitcast16:
      value = LLVM::BitcastOp::create(
    emitCall(rewriter, loc, printer.value(), value);
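/// Lowers vector.broadcast of a scalar to a 0-D or 1-D result: the scalar is
/// inserted into a poison vector (and, for 1-D results, splatted across the
/// lanes).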
struct VectorBroadcastScalarToLowRankLowering
  matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
    if (isa<VectorType>(broadcast.getSourceType()))
          broadcast, "broadcast from vector type not handled");
    if (resultType.getRank() > 1)
          "broadcast to 2+-d handled elsewhere");
    auto zero = LLVM::ConstantOp::create(
    if (resultType.getRank() == 0) {
          broadcast, vectorType, poison, adaptor.getSource(), zero);
        LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
                                      poison, adaptor.getSource(), zero);
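/// Lowers vector.broadcast of a scalar to an n-D (n > 1) result: a splatted
/// 1-D vector is built once and then inserted into every 1-D position of the
/// n-D aggregate.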
struct VectorBroadcastScalarToNdLowering
  matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
    if (isa<VectorType>(broadcast.getSourceType()))
          broadcast, "broadcast from vector type not handled");
    if (resultType.getRank() <= 1)
          broadcast, "broadcast to 1-d or 0-d handled elsewhere");
    auto vectorTypeInfo =
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
    Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
    Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
    auto zero = LLVM::ConstantOp::create(
        rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
    Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
                                            vdesc, adaptor.getSource(), zero);
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
      desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
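/// Lowers 1-D vector.interleave. Scalable vectors use the LLVM
/// vector.interleave2 intrinsic; fixed-length vectors become a single
/// shufflevector with the mask [0, n, 1, n+1, ...], where n is half the
/// result size.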
struct VectorInterleaveOpLowering
  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
    VectorType resultType = interleaveOp.getResultVectorType();
    if (resultType.getRank() != 1)
          "InterleaveOp not rank 1");
    if (resultType.isScalable()) {
          interleaveOp, typeConverter->convertType(resultType),
          adaptor.getLhs(), adaptor.getRhs());
    int64_t resultVectorSize = resultType.getNumElements();
    interleaveShuffleMask.reserve(resultVectorSize);
    for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
      interleaveShuffleMask.push_back(i);
      interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
        interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
        interleaveShuffleMask);
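/// Lowers 1-D vector.deinterleave. Scalable sources use the LLVM
/// vector.deinterleave2 intrinsic and unpack its struct result; fixed-length
/// sources use two shufflevectors selecting the even and odd lanes.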
struct VectorDeinterleaveOpLowering
  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
    VectorType resultType = deinterleaveOp.getResultVectorType();
    VectorType sourceType = deinterleaveOp.getSourceVectorType();
    auto loc = deinterleaveOp.getLoc();
    if (resultType.getRank() != 1)
          "DeinterleaveOp not rank 1");
    if (resultType.isScalable()) {
      auto llvmTypeConverter = this->getTypeConverter();
      auto deinterleaveResults = deinterleaveOp.getResultTypes();
      auto packedOpResults =
          llvmTypeConverter->packOperationResults(deinterleaveResults);
      auto intrinsic = LLVM::vector_deinterleave2::create(
          rewriter, loc, packedOpResults, adaptor.getSource());
      auto evenResult = LLVM::ExtractValueOp::create(
          rewriter, loc, intrinsic->getResult(0), 0);
      auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
                                                    intrinsic->getResult(0), 1);
    int64_t resultVectorSize = resultType.getNumElements();
    evenShuffleMask.reserve(resultVectorSize);
    oddShuffleMask.reserve(resultVectorSize);
    for (int i = 0; i < sourceType.getNumElements(); ++i) {
        evenShuffleMask.push_back(i);
        oddShuffleMask.push_back(i);
    auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
    auto evenShuffle = LLVM::ShuffleVectorOp::create(
        rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
    auto oddShuffle = LLVM::ShuffleVectorOp::create(
        rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
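/// Lowers 1-D vector.from_elements by inserting each element into a poison
/// vector at its index.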
struct VectorFromElementsLowering
  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
    Location loc = fromElementsOp.getLoc();
    VectorType vectorType = fromElementsOp.getType();
    if (vectorType.getRank() > 1)
          "rank > 1 vectors are not supported");
    Type llvmType = typeConverter->convertType(vectorType);
    Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
          LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
      result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
    rewriter.replaceOp(fromElementsOp, result);
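/// Lowers vector.to_elements by emitting one extractelement per result that
/// actually has uses.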
struct VectorToElementsLowering
  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
    Location loc = toElementsOp.getLoc();
    auto idxType = typeConverter->convertType(rewriter.getIndexType());
    Value source = adaptor.getSource();
    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
      if (element.use_empty())
      auto constIdx = LLVM::ConstantOp::create(
      auto llvmType = typeConverter->convertType(element.getType());
      Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
      results[idx] = result;
    rewriter.replaceOp(toElementsOp, results);
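/// Lowers scalable vector.step to the LLVM step-vector intrinsic;
/// fixed-length steps are not handled by this pattern.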
struct VectorScalableStepOpLowering
  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
    auto resultType = cast<VectorType>(stepOp.getType());
    if (!resultType.isScalable()) {
    Type llvmType = typeConverter->convertType(stepOp.getType());
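/// Progressively lowers a 2-D vector.contract with matmul semantics to
/// llvm.intr.matrix.multiply: the operands are flattened to 1-D, multiplied,
/// reshaped back, and accumulated with addi/addf.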
class ContractionOpToMatmulOpLowering
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  ContractionOpToMatmulOpLowering(
      vector::VectorContractLowering vectorContractLowering,
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
  auto iteratorTypes = op.getIteratorTypes().getValue();
  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
  Type elementType = op.getLhsType().getElementType();
  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);
  Type flattenedLHSType =
  lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);
  Type flattenedRHSType =
  rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);
  Value mul = LLVM::MatrixMultiplyOp::create(
                      cast<VectorType>(lhs.getType()).getElementType()),
      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
  mul = vector::ShapeCastOp::create(
  auto accMap = op.getIndexingMapsArray()[2];
    llvm_unreachable("invalid contraction semantics");
  Value res = isa<IntegerType>(elementType)
                  ? static_cast<Value>(
                        arith::AddIOp::create(rew, loc, op.getAcc(), mul))
                  : static_cast<Value>(
                        arith::AddFOp::create(rew, loc, op.getAcc(), mul));
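/// Lowers a fixed-length 2-D vector.transpose to llvm.intr.matrix.transpose
/// on the flattened vector; vector.splat is rewritten to vector.broadcast.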
class TransposeOpToMatrixTransposeOpLowering
  LogicalResult matchAndRewrite(vector::TransposeOp op,
    auto loc = op.getLoc();
    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();
    if (inputType.isScalable())
          op, "This lowering does not support scalable vectors");
    if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
    Type flattenedType =
        vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
    Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
                                                  matrix, rows, columns);
  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
        adaptor.getInput());
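/// Pattern population entry points: the matmul and flat-transpose lowerings
/// are registered with a caller-provided benefit, and
/// populateVectorToLLVMConversionPatterns registers the full set of
/// Vector-to-LLVM conversion patterns defined above.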
  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
  patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
    bool reassociateFPReductions, bool force32BitVectorIndices,
    bool useVectorAlignment) {
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
               VectorLoadStoreConversion<vector::MaskedLoadOp>,
               VectorLoadStoreConversion<vector::StoreOp>,
               VectorLoadStoreConversion<vector::MaskedStoreOp>,
               VectorGatherOpConversion, VectorScatterOpConversion>(
      converter, useVectorAlignment);
  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
               VectorExtractOpConversion, VectorFMAOp1DConversion,
               VectorInsertOpConversion, VectorPrintOpConversion,
               VectorTypeCastOpConversion, VectorScaleOpConversion,
               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
               VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
               VectorBroadcastScalarToNdLowering,
               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
               MaskedReductionOpConversion, VectorInterleaveOpLowering,
               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
               VectorToElementsLowering, VectorScalableStepOpLowering>(
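/// ConvertToLLVMPatternInterface implementation that plugs the Vector-to-LLVM
/// patterns into the generic convert-to-llvm infrastructure.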
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  void populateConvertToLLVMConversionPatterns(
    dialect->addInterfaces<VectorToLLVMDialectInterface>();