30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
39 assert((tp.getRank() > 1) &&
"unlowerable vector type");
41 tp.getScalableDims().take_back());
49 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
52 auto constant = rewriter.
create<LLVM::ConstantOp>(
55 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
64 Value val,
Type llvmType, int64_t rank, int64_t pos) {
67 auto constant = rewriter.
create<LLVM::ConstantOp>(
70 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
73 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
78 MemRefType memrefType,
unsigned &align) {
79 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
85 llvm::LLVMContext llvmContext;
104 MemRefType memRefType,
Value llvmMemref,
Value base,
105 Value index, uint64_t vLen) {
107 "unsupported memref type");
110 return rewriter.
create<LLVM::GEPOp>(
111 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
119 if (
auto attr = foldResult.dyn_cast<
Attribute>()) {
120 auto intAttr = cast<IntegerAttr>(attr);
121 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
124 return foldResult.get<
Value>();
130 using VectorScaleOpConversion =
134 class VectorBitCastOpConversion
140 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
143 VectorType resultTy = bitCastOp.getResultVectorType();
144 if (resultTy.getRank() > 1)
146 Type newResultTy = typeConverter->convertType(resultTy);
148 adaptor.getOperands()[0]);
155 class VectorMatmulOpConversion
161 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
164 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
165 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
166 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
173 class VectorFlatTransposeOpConversion
179 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
182 transOp, typeConverter->convertType(transOp.getRes().getType()),
183 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
191 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
192 vector::LoadOpAdaptor adaptor,
193 VectorType vectorTy,
Value ptr,
unsigned align,
198 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
199 vector::MaskedLoadOpAdaptor adaptor,
200 VectorType vectorTy,
Value ptr,
unsigned align,
203 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
206 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
207 vector::StoreOpAdaptor adaptor,
208 VectorType vectorTy,
Value ptr,
unsigned align,
214 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
215 vector::MaskedStoreOpAdaptor adaptor,
216 VectorType vectorTy,
Value ptr,
unsigned align,
219 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
224 template <
class LoadOrStoreOp,
class LoadOrStoreOpAdaptor>
230 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
231 typename LoadOrStoreOp::Adaptor adaptor,
234 VectorType vectorTy = loadOrStoreOp.getVectorType();
235 if (vectorTy.getRank() > 1)
238 auto loc = loadOrStoreOp->getLoc();
239 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
247 auto vtype = cast<VectorType>(
248 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
249 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
250 adaptor.getIndices(), rewriter);
251 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
258 class VectorGatherOpConversion
264 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
266 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
267 assert(memRefType &&
"The base should be bufferized");
272 auto loc = gather->getLoc();
279 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
280 adaptor.getIndices(), rewriter);
281 Value base = adaptor.getBase();
283 auto llvmNDVectorTy = adaptor.getIndexVec().
getType();
285 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
286 auto vType = gather.getVectorType();
289 memRefType, base, ptr, adaptor.getIndexVec(),
290 vType.getDimSize(0));
293 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
299 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
300 &typeConverter](
Type llvm1DVectorTy,
304 rewriter, loc, typeConverter, memRefType, base, ptr,
308 return rewriter.create<LLVM::masked_gather>(
309 loc, llvm1DVectorTy, ptrs, vectorOperands[1],
310 vectorOperands[2], rewriter.getI32IntegerAttr(align));
313 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
315 gather, vectorOperands, *getTypeConverter(), callback, rewriter);
320 class VectorScatterOpConversion
326 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
328 auto loc = scatter->getLoc();
329 MemRefType memRefType = scatter.getMemRefType();
340 VectorType vType = scatter.getVectorType();
341 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
342 adaptor.getIndices(), rewriter);
344 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
345 ptr, adaptor.getIndexVec(), vType.getDimSize(0));
349 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
356 class VectorExpandLoadOpConversion
362 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
364 auto loc = expand->getLoc();
365 MemRefType memRefType = expand.getMemRefType();
368 auto vtype = typeConverter->
convertType(expand.getVectorType());
369 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
370 adaptor.getIndices(), rewriter);
373 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
379 class VectorCompressStoreOpConversion
385 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
387 auto loc = compress->getLoc();
388 MemRefType memRefType = compress.getMemRefType();
391 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
392 adaptor.getIndices(), rewriter);
395 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Call sites in this file pass it as the neutral element for ADD (integer and
/// floating-point), OR and XOR reductions.
401 class ReductionNeutralZero {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the integer neutral element of the masked MUL
/// reduction.
402 class ReductionNeutralIntOne {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the floating-point neutral element of MUL reductions
/// (both the `vector_reduce_fmul` and the masked `VPReduceFMulOp` lowering).
403 class ReductionNeutralFPOne {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked AND reduction.
404 class ReductionNeutralAllOnes {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked signed-integer
/// maximum (MAXSI) reduction.
405 class ReductionNeutralSIntMin {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked unsigned-integer
/// maximum (MAXUI) reduction.
406 class ReductionNeutralUIntMin {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked signed-integer
/// minimum (MINSI) reduction.
407 class ReductionNeutralSIntMax {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked unsigned-integer
/// minimum (MINUI) reduction.
408 class ReductionNeutralUIntMax {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked floating-point MAXF
/// reduction. Note that the matching overload builds its constant from
/// `llvm::APFloat::getQNaN`, not from a finite minimum value.
409 class ReductionNeutralFPMin {};
/// Empty tag type used to select a `createReductionNeutralValue` overload.
/// Used in this file as the neutral element of the masked floating-point MINF
/// reduction. Note that the matching overload builds its constant from
/// `llvm::APFloat::getQNaN`, not from a finite maximum value.
410 class ReductionNeutralFPMax {};
413 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
416 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
421 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
424 return rewriter.
create<LLVM::ConstantOp>(
429 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
432 return rewriter.
create<LLVM::ConstantOp>(
437 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
440 return rewriter.
create<LLVM::ConstantOp>(
447 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
450 return rewriter.
create<LLVM::ConstantOp>(
457 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
460 return rewriter.
create<LLVM::ConstantOp>(
467 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
470 return rewriter.
create<LLVM::ConstantOp>(
477 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
480 return rewriter.
create<LLVM::ConstantOp>(
487 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
490 auto floatType = cast<FloatType>(llvmType);
491 return rewriter.
create<LLVM::ConstantOp>(
494 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
499 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
502 auto floatType = cast<FloatType>(llvmType);
503 return rewriter.
create<LLVM::ConstantOp>(
506 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
512 template <
class ReductionNeutral>
519 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
528 VectorType vType = cast<VectorType>(llvmType);
529 auto vShape = vType.getShape();
530 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
532 return rewriter.
create<LLVM::ConstantOp>(
541 template <
class LLVMRedIntrinOp,
class ScalarOp>
542 static Value createIntegerReductionArithmeticOpLowering(
546 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
549 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
557 template <
class LLVMRedIntrinOp>
558 static Value createIntegerReductionComparisonOpLowering(
560 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
561 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
564 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
565 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
/// Compile-time map from an LLVM vector-reduction intrinsic op (`Source`) to
/// the scalar LLVM op exposed as the nested `Type` alias. Only explicit
/// specializations are defined; the primary template is declared but left
/// undefined. `createFPReductionComparisonOpLowering` looks up
/// `VectorToScalarMapper<LLVMRedIntrinOp>::Type` to combine the reduced value
/// with the accumulator.
571 template <
typename Source>
572 struct VectorToScalarMapper;
574 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
575 using Type = LLVM::MaximumOp;
578 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
579 using Type = LLVM::MinimumOp;
582 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
583 using Type = LLVM::MaxNumOp;
586 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
587 using Type = LLVM::MinNumOp;
591 template <
class LLVMRedIntrinOp>
592 static Value createFPReductionComparisonOpLowering(
594 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
596 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
600 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
601 loc, result, accumulator);
/// Empty tag type selecting the `getMaskNeutralValue` overload for masked
/// MAXIMUMF reductions. The value it yields is splatted and substituted (via
/// `LLVM::SelectOp`) for the masked-off lanes before the regular
/// `vector_reduce_fmaximum` lowering runs.
/// NOTE(review): the matching overload returns
/// `llvm::APFloat::getSmallest(floatSemantics, true)`. Per the APFloat API this
/// is the smallest-*magnitude* negative value, not the most negative finite
/// value, so masked-off lanes could win over more-negative active lanes.
/// Confirm whether `getLargest(floatSemantics, true)` was intended.
608 class MaskNeutralFMaximum {};
/// Empty tag type selecting the `getMaskNeutralValue` overload for masked
/// MINIMUMF reductions. That overload returns
/// `llvm::APFloat::getLargest(floatSemantics, false)`, which replaces the
/// masked-off lanes before the regular `vector_reduce_fminimum` lowering runs.
609 class MaskNeutralFMinimum {};
613 getMaskNeutralValue(MaskNeutralFMaximum,
614 const llvm::fltSemantics &floatSemantics) {
615 return llvm::APFloat::getSmallest(floatSemantics,
true);
619 getMaskNeutralValue(MaskNeutralFMinimum,
620 const llvm::fltSemantics &floatSemantics) {
621 return llvm::APFloat::getLargest(floatSemantics,
false);
625 template <
typename MaskNeutral>
629 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
630 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
633 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
640 template <
class LLVMRedIntrinOp,
class MaskNeutral>
645 Value mask, LLVM::FastmathFlagsAttr fmf) {
646 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
647 rewriter, loc, llvmType, vectorOperand.
getType());
648 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
649 loc, mask, vectorOperand, vectorMaskNeutral);
650 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
651 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
654 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
658 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
659 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
660 llvmType, accumulator);
661 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
669 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675 llvmType, accumulator);
676 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
681 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
682 static Value lowerPredicatedReductionWithStartValue(
685 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
686 llvmType, accumulator);
688 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
689 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
691 vectorOperand, mask, vectorLength);
694 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
695 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
696 static Value lowerPredicatedReductionWithStartValue(
700 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
701 IntReductionNeutral>(
702 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
705 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
707 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
711 class VectorReductionOpConversion
715 bool reassociateFPRed)
717 reassociateFPReductions(reassociateFPRed) {}
720 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
722 auto kind = reductionOp.getKind();
723 Type eltType = reductionOp.getDest().getType();
725 Value operand = adaptor.getVector();
726 Value acc = adaptor.getAcc();
727 Location loc = reductionOp.getLoc();
733 case vector::CombiningKind::ADD:
735 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
737 rewriter, loc, llvmType, operand, acc);
739 case vector::CombiningKind::MUL:
741 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
743 rewriter, loc, llvmType, operand, acc);
745 case vector::CombiningKind::MINUI:
746 result = createIntegerReductionComparisonOpLowering<
747 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
748 LLVM::ICmpPredicate::ule);
750 case vector::CombiningKind::MINSI:
751 result = createIntegerReductionComparisonOpLowering<
752 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
753 LLVM::ICmpPredicate::sle);
755 case vector::CombiningKind::MAXUI:
756 result = createIntegerReductionComparisonOpLowering<
757 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
758 LLVM::ICmpPredicate::uge);
760 case vector::CombiningKind::MAXSI:
761 result = createIntegerReductionComparisonOpLowering<
762 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
763 LLVM::ICmpPredicate::sge);
765 case vector::CombiningKind::AND:
767 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
769 rewriter, loc, llvmType, operand, acc);
771 case vector::CombiningKind::OR:
773 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
775 rewriter, loc, llvmType, operand, acc);
777 case vector::CombiningKind::XOR:
779 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
781 rewriter, loc, llvmType, operand, acc);
791 if (!isa<FloatType>(eltType))
794 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
796 reductionOp.getContext(),
799 reductionOp.getContext(),
800 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
801 : LLVM::FastmathFlags::none));
805 if (kind == vector::CombiningKind::ADD) {
806 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
807 ReductionNeutralZero>(
808 rewriter, loc, llvmType, operand, acc, fmf);
809 }
else if (kind == vector::CombiningKind::MUL) {
810 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
811 ReductionNeutralFPOne>(
812 rewriter, loc, llvmType, operand, acc, fmf);
813 }
else if (kind == vector::CombiningKind::MINIMUMF) {
815 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
816 rewriter, loc, llvmType, operand, acc, fmf);
817 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
819 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
820 rewriter, loc, llvmType, operand, acc, fmf);
821 }
else if (kind == vector::CombiningKind::MINF) {
822 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
823 rewriter, loc, llvmType, operand, acc, fmf);
824 }
else if (kind == vector::CombiningKind::MAXF) {
825 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
826 rewriter, loc, llvmType, operand, acc, fmf);
835 const bool reassociateFPReductions;
846 template <
class MaskedOp>
847 class VectorMaskOpConversionBase
853 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
856 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
859 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
864 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
865 vector::MaskableOpInterface maskableOp,
869 class MaskedReductionOpConversion
870 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
873 using VectorMaskOpConversionBase<
874 vector::ReductionOp>::VectorMaskOpConversionBase;
877 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
879 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
880 auto kind = reductionOp.getKind();
881 Type eltType = reductionOp.getDest().getType();
883 Value operand = reductionOp.getVector();
884 Value acc = reductionOp.getAcc();
885 Location loc = reductionOp.getLoc();
887 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
889 reductionOp.getContext(),
894 case vector::CombiningKind::ADD:
895 result = lowerPredicatedReductionWithStartValue<
896 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
897 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
900 case vector::CombiningKind::MUL:
901 result = lowerPredicatedReductionWithStartValue<
902 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
903 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
906 case vector::CombiningKind::MINUI:
907 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
908 ReductionNeutralUIntMax>(
909 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
911 case vector::CombiningKind::MINSI:
912 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
913 ReductionNeutralSIntMax>(
914 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
916 case vector::CombiningKind::MAXUI:
917 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
918 ReductionNeutralUIntMin>(
919 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
921 case vector::CombiningKind::MAXSI:
922 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
923 ReductionNeutralSIntMin>(
924 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
926 case vector::CombiningKind::AND:
927 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
928 ReductionNeutralAllOnes>(
929 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
931 case vector::CombiningKind::OR:
932 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
933 ReductionNeutralZero>(
934 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
936 case vector::CombiningKind::XOR:
937 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
938 ReductionNeutralZero>(
939 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
941 case vector::CombiningKind::MINF:
942 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
943 ReductionNeutralFPMax>(
944 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
946 case vector::CombiningKind::MAXF:
947 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
948 ReductionNeutralFPMin>(
949 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
951 case CombiningKind::MAXIMUMF:
952 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
953 MaskNeutralFMaximum>(
954 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
956 case CombiningKind::MINIMUMF:
957 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
958 MaskNeutralFMinimum>(
959 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
964 rewriter.replaceOp(maskOp, result);
969 class VectorShuffleOpConversion
975 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
977 auto loc = shuffleOp->getLoc();
978 auto v1Type = shuffleOp.getV1VectorType();
979 auto v2Type = shuffleOp.getV2VectorType();
980 auto vectorType = shuffleOp.getResultVectorType();
982 auto maskArrayAttr = shuffleOp.getMask();
989 int64_t rank = vectorType.getRank();
991 bool wellFormed0DCase =
992 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
993 bool wellFormedNDCase =
994 v1Type.getRank() == rank && v2Type.getRank() == rank;
995 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1000 if (rank <= 1 && v1Type == v2Type) {
1001 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1002 loc, adaptor.getV1(), adaptor.getV2(),
1003 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1004 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1009 int64_t v1Dim = v1Type.getDimSize(0);
1011 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1012 eltType = arrayType.getElementType();
1014 eltType = cast<VectorType>(llvmType).getElementType();
1015 Value insert = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1018 int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1019 Value value = adaptor.getV1();
1020 if (extPos >= v1Dim) {
1022 value = adaptor.getV2();
1025 eltType, rank, extPos);
1026 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1027 llvmType, rank, insPos++);
1034 class VectorExtractElementOpConversion
1041 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1043 auto vectorType = extractEltOp.getSourceVectorType();
1044 auto llvmType = typeConverter->
convertType(vectorType.getElementType());
1050 if (vectorType.getRank() == 0) {
1051 Location loc = extractEltOp.getLoc();
1053 auto zero = rewriter.
create<LLVM::ConstantOp>(
1057 extractEltOp, llvmType, adaptor.getVector(), zero);
1062 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1067 class VectorExtractOpConversion
1073 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1075 auto loc = extractOp->getLoc();
1076 auto resultType = extractOp.getResult().getType();
1077 auto llvmResultType = typeConverter->
convertType(resultType);
1079 if (!llvmResultType)
1083 for (
auto [idx, pos] :
llvm::enumerate(extractOp.getMixedPosition())) {
1084 if (pos.is<
Value>())
1086 positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1088 positionVec.push_back(pos);
1093 if (position.empty()) {
1094 rewriter.
replaceOp(extractOp, adaptor.getVector());
1099 if (isa<VectorType>(resultType)) {
1100 if (extractOp.hasDynamicPosition())
1103 Value extracted = rewriter.
create<LLVM::ExtractValueOp>(
1105 rewriter.
replaceOp(extractOp, extracted);
1110 Value extracted = adaptor.getVector();
1111 if (position.size() > 1) {
1112 if (extractOp.hasDynamicPosition())
1117 extracted = rewriter.
create<LLVM::ExtractValueOp>(loc, extracted,
1148 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1150 VectorType vType = fmaOp.getVectorType();
1151 if (vType.getRank() > 1)
1155 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1160 class VectorInsertElementOpConversion
1166 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1168 auto vectorType = insertEltOp.getDestVectorType();
1169 auto llvmType = typeConverter->
convertType(vectorType);
1175 if (vectorType.getRank() == 0) {
1176 Location loc = insertEltOp.getLoc();
1178 auto zero = rewriter.
create<LLVM::ConstantOp>(
1182 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1187 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1188 adaptor.getPosition());
1193 class VectorInsertOpConversion
1199 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1201 auto loc = insertOp->getLoc();
1202 auto sourceType = insertOp.getSourceType();
1203 auto destVectorType = insertOp.getDestVectorType();
1204 auto llvmResultType = typeConverter->
convertType(destVectorType);
1206 if (!llvmResultType)
1210 for (
auto [idx, pos] :
llvm::enumerate(insertOp.getMixedPosition())) {
1211 if (pos.is<
Value>())
1213 positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1215 positionVec.push_back(pos);
1221 if (position.empty()) {
1222 rewriter.
replaceOp(insertOp, adaptor.getSource());
1227 if (isa<VectorType>(sourceType)) {
1228 if (insertOp.hasDynamicPosition())
1231 Value inserted = rewriter.
create<LLVM::InsertValueOp>(
1232 loc, adaptor.getDest(), adaptor.getSource(),
getAsIntegers(position));
1238 Value extracted = adaptor.getDest();
1239 auto oneDVectorType = destVectorType;
1240 if (position.size() > 1) {
1241 if (insertOp.hasDynamicPosition())
1245 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1250 Value inserted = rewriter.
create<LLVM::InsertElementOp>(
1251 loc, typeConverter->
convertType(oneDVectorType), extracted,
1252 adaptor.getSource(),
getAsLLVMValue(rewriter, loc, position.back()));
1255 if (position.size() > 1) {
1256 if (insertOp.hasDynamicPosition())
1259 inserted = rewriter.
create<LLVM::InsertValueOp>(
1260 loc, adaptor.getDest(), inserted,
1270 struct VectorScalableInsertOpLowering
1276 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1279 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1285 struct VectorScalableExtractOpLowering
1291 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1294 extOp, typeConverter->
convertType(extOp.getResultVectorType()),
1295 adaptor.getSource(), adaptor.getPos());
1328 setHasBoundedRewriteRecursion();
1333 auto vType = op.getVectorType();
1334 if (vType.getRank() < 2)
1338 auto elemType = vType.getElementType();
1341 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1342 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1343 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1344 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1345 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1346 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1347 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
1356 static std::optional<SmallVector<int64_t, 4>>
1357 computeContiguousStrides(MemRefType memRefType) {
1361 return std::nullopt;
1362 if (!strides.empty() && strides.back() != 1)
1363 return std::nullopt;
1365 if (memRefType.getLayout().isIdentity())
1372 auto sizes = memRefType.getShape();
1373 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1374 if (ShapedType::isDynamic(sizes[index + 1]) ||
1375 ShapedType::isDynamic(strides[index]) ||
1376 ShapedType::isDynamic(strides[index + 1]))
1377 return std::nullopt;
1378 if (strides[index] != strides[index + 1] * sizes[index + 1])
1379 return std::nullopt;
1384 class VectorTypeCastOpConversion
1390 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1392 auto loc = castOp->getLoc();
1393 MemRefType sourceMemRefType =
1394 cast<MemRefType>(castOp.getOperand().getType());
1395 MemRefType targetMemRefType = castOp.getType();
1398 if (!sourceMemRefType.hasStaticShape() ||
1399 !targetMemRefType.hasStaticShape())
1402 auto llvmSourceDescriptorTy =
1403 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1404 if (!llvmSourceDescriptorTy)
1408 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1410 if (!llvmTargetDescriptorTy)
1414 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1417 auto targetStrides = computeContiguousStrides(targetMemRefType);
1421 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1429 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1430 desc.setAllocatedPtr(rewriter, loc, allocated);
1433 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1434 desc.setAlignedPtr(rewriter, loc, ptr);
1437 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1438 desc.setOffset(rewriter, loc, zero);
1441 for (
const auto &indexedSize :
1443 int64_t index = indexedSize.index();
1446 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1447 desc.setSize(rewriter, loc, index, size);
1449 (*targetStrides)[index]);
1450 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1451 desc.setStride(rewriter, loc, index, stride);
1461 class VectorCreateMaskOpRewritePattern
1464 explicit VectorCreateMaskOpRewritePattern(
MLIRContext *context,
1465 bool enableIndexOpt)
1467 force32BitVectorIndices(enableIndexOpt) {}
1471 auto dstType = op.getType();
1472 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1474 IntegerType idxType =
1477 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1483 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1490 const bool force32BitVectorIndices;
1511 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1513 auto parent =
printOp->getParentOfType<ModuleOp>();
1519 if (
auto value = adaptor.getSource()) {
1529 auto punct =
printOp.getPunctuation();
1530 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1532 *stringLiteral, *getTypeConverter());
1533 }
else if (punct != PrintPunctuation::NoPunctuation) {
1534 emitCall(rewriter,
printOp->getLoc(), [&] {
1536 case PrintPunctuation::Close:
1537 return LLVM::lookupOrCreatePrintCloseFn(parent);
1538 case PrintPunctuation::Open:
1539 return LLVM::lookupOrCreatePrintOpenFn(parent);
1540 case PrintPunctuation::Comma:
1541 return LLVM::lookupOrCreatePrintCommaFn(parent);
1542 case PrintPunctuation::NewLine:
1543 return LLVM::lookupOrCreatePrintNewlineFn(parent);
1545 llvm_unreachable(
"unexpected punctuation");
1555 enum class PrintConversion {
1566 Value value)
const {
1578 conversion = PrintConversion::Bitcast16;
1581 conversion = PrintConversion::Bitcast16;
1585 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1589 unsigned width = intTy.getWidth();
1590 if (intTy.isUnsigned()) {
1593 conversion = PrintConversion::ZeroExt64;
1599 assert(intTy.isSignless() || intTy.isSigned());
1604 conversion = PrintConversion::ZeroExt64;
1605 else if (width < 64)
1606 conversion = PrintConversion::SignExt64;
1616 switch (conversion) {
1617 case PrintConversion::ZeroExt64:
1618 value = rewriter.
create<arith::ExtUIOp>(
1621 case PrintConversion::SignExt64:
1622 value = rewriter.
create<arith::ExtSIOp>(
1625 case PrintConversion::Bitcast16:
1626 value = rewriter.
create<LLVM::BitcastOp>(
1632 emitCall(rewriter, loc, printer, value);
1650 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1652 VectorType resultType = cast<VectorType>(splatOp.getType());
1653 if (resultType.getRank() > 1)
1657 auto vectorType = typeConverter->
convertType(splatOp.getType());
1658 Value undef = rewriter.
create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1659 auto zero = rewriter.
create<LLVM::ConstantOp>(
1665 if (resultType.getRank() == 0) {
1667 splatOp, vectorType, undef, adaptor.getInput(), zero);
1672 auto v = rewriter.
create<LLVM::InsertElementOp>(
1673 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1675 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1692 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1694 VectorType resultType = splatOp.getType();
1695 if (resultType.getRank() <= 1)
1699 auto loc = splatOp.getLoc();
1700 auto vectorTypeInfo =
1702 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1703 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1704 if (!llvmNDVectorTy || !llvm1DVectorTy)
1708 Value desc = rewriter.
create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1712 Value vdesc = rewriter.
create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1713 auto zero = rewriter.
create<LLVM::ConstantOp>(
1716 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1717 adaptor.getInput(), zero);
1720 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1722 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1727 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1739 bool reassociateFPReductions,
bool force32BitVectorIndices) {
1741 patterns.
add<VectorFMAOpNDRewritePattern>(ctx);
1743 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1744 patterns.
add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1746 .
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1747 VectorExtractElementOpConversion, VectorExtractOpConversion,
1748 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1749 VectorInsertOpConversion, VectorPrintOpConversion,
1750 VectorTypeCastOpConversion, VectorScaleOpConversion,
1751 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1752 VectorLoadStoreConversion<vector::MaskedLoadOp,
1753 vector::MaskedLoadOpAdaptor>,
1754 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1755 VectorLoadStoreConversion<vector::MaskedStoreOp,
1756 vector::MaskedStoreOpAdaptor>,
1757 VectorGatherOpConversion, VectorScatterOpConversion,
1758 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1759 VectorSplatOpLowering, VectorSplatNdOpLowering,
1760 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1761 MaskedReductionOpConversion>(converter);
1768 patterns.
add<VectorMatmulOpConversion>(converter);
1769 patterns.
add<VectorFlatTransposeOpConversion>(converter);
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, uint64_t vLen)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp)
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...