30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
39 assert((tp.getRank() > 1) &&
"unlowerable vector type");
41 tp.getScalableDims().take_back());
49 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
52 auto constant = rewriter.
create<LLVM::ConstantOp>(
55 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
64 Value val,
Type llvmType, int64_t rank, int64_t pos) {
67 auto constant = rewriter.
create<LLVM::ConstantOp>(
70 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
73 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
78 MemRefType memrefType,
unsigned &align) {
79 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
85 llvm::LLVMContext llvmContext;
104 MemRefType memRefType,
Value llvmMemref,
Value base,
105 Value index, VectorType vectorType) {
107 "unsupported memref type");
108 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
112 vectorType.getScalableDims()[0]);
113 return rewriter.
create<LLVM::GEPOp>(
114 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
122 if (
auto attr = foldResult.dyn_cast<
Attribute>()) {
123 auto intAttr = cast<IntegerAttr>(attr);
124 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
127 return foldResult.get<
Value>();
133 using VectorScaleOpConversion =
137 class VectorBitCastOpConversion
143 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
146 VectorType resultTy = bitCastOp.getResultVectorType();
147 if (resultTy.getRank() > 1)
149 Type newResultTy = typeConverter->convertType(resultTy);
151 adaptor.getOperands()[0]);
158 class VectorMatmulOpConversion
164 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
167 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
168 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
169 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
176 class VectorFlatTransposeOpConversion
182 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
185 transOp, typeConverter->convertType(transOp.getRes().getType()),
186 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
195 vector::LoadOpAdaptor adaptor,
196 VectorType vectorTy,
Value ptr,
unsigned align,
200 loadOp.getNontemporal());
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204 vector::MaskedLoadOpAdaptor adaptor,
205 VectorType vectorTy,
Value ptr,
unsigned align,
208 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212 vector::StoreOpAdaptor adaptor,
213 VectorType vectorTy,
Value ptr,
unsigned align,
217 storeOp.getNontemporal());
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221 vector::MaskedStoreOpAdaptor adaptor,
222 VectorType vectorTy,
Value ptr,
unsigned align,
225 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
230 template <
class LoadOrStoreOp>
236 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
237 typename LoadOrStoreOp::Adaptor adaptor,
240 VectorType vectorTy = loadOrStoreOp.getVectorType();
241 if (vectorTy.getRank() > 1)
244 auto loc = loadOrStoreOp->getLoc();
245 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
253 auto vtype = cast<VectorType>(
254 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
255 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
256 adaptor.getIndices(), rewriter);
257 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
264 class VectorGatherOpConversion
270 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
272 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
273 assert(memRefType &&
"The base should be bufferized");
278 auto loc = gather->getLoc();
285 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
286 adaptor.getIndices(), rewriter);
287 Value base = adaptor.getBase();
289 auto llvmNDVectorTy = adaptor.getIndexVec().
getType();
291 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
292 auto vType = gather.getVectorType();
295 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
296 base, ptr, adaptor.getIndexVec(), vType);
299 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
305 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
306 &typeConverter](
Type llvm1DVectorTy,
310 rewriter, loc, typeConverter, memRefType, base, ptr,
311 vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
313 return rewriter.create<LLVM::masked_gather>(
314 loc, llvm1DVectorTy, ptrs, vectorOperands[1],
315 vectorOperands[2], rewriter.getI32IntegerAttr(align));
318 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
320 gather, vectorOperands, *getTypeConverter(), callback, rewriter);
325 class VectorScatterOpConversion
331 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
333 auto loc = scatter->getLoc();
334 MemRefType memRefType = scatter.getMemRefType();
345 VectorType vType = scatter.getVectorType();
346 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
347 adaptor.getIndices(), rewriter);
349 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
350 adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
354 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
361 class VectorExpandLoadOpConversion
367 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
369 auto loc = expand->getLoc();
370 MemRefType memRefType = expand.getMemRefType();
373 auto vtype = typeConverter->
convertType(expand.getVectorType());
374 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
375 adaptor.getIndices(), rewriter);
378 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
384 class VectorCompressStoreOpConversion
390 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
392 auto loc = compress->getLoc();
393 MemRefType memRefType = compress.getMemRefType();
396 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397 adaptor.getIndices(), rewriter);
400 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
/// Empty tag types used purely for overload selection: each one is accepted
/// as the first parameter of a matching `createReductionNeutralValue`
/// overload, and callers pass a default-constructed instance
/// (`ReductionNeutral()`) to pick the neutral start value for a reduction.
/// The types carry no state.
406 class ReductionNeutralZero {};
407 class ReductionNeutralIntOne {};
408 class ReductionNeutralFPOne {};
409 class ReductionNeutralAllOnes {};
410 class ReductionNeutralSIntMin {};
411 class ReductionNeutralUIntMin {};
412 class ReductionNeutralSIntMax {};
413 class ReductionNeutralUIntMax {};
414 class ReductionNeutralFPMin {};
415 class ReductionNeutralFPMax {};
418 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
421 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
429 return rewriter.
create<LLVM::ConstantOp>(
434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
437 return rewriter.
create<LLVM::ConstantOp>(
442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
445 return rewriter.
create<LLVM::ConstantOp>(
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
455 return rewriter.
create<LLVM::ConstantOp>(
462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
465 return rewriter.
create<LLVM::ConstantOp>(
472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
475 return rewriter.
create<LLVM::ConstantOp>(
482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
485 return rewriter.
create<LLVM::ConstantOp>(
492 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
495 auto floatType = cast<FloatType>(llvmType);
496 return rewriter.
create<LLVM::ConstantOp>(
499 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
504 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
507 auto floatType = cast<FloatType>(llvmType);
508 return rewriter.
create<LLVM::ConstantOp>(
511 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
517 template <
class ReductionNeutral>
524 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
533 VectorType vType = cast<VectorType>(llvmType);
534 auto vShape = vType.getShape();
535 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
537 Value baseVecLength = rewriter.
create<LLVM::ConstantOp>(
541 if (!vType.getScalableDims()[0])
542 return baseVecLength;
545 Value vScale = rewriter.
create<vector::VectorScaleOp>(loc);
548 Value scalableVecLength =
549 rewriter.
create<arith::MulIOp>(loc, baseVecLength, vScale);
550 return scalableVecLength;
557 template <
class LLVMRedIntrinOp,
class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
562 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
565 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
573 template <
class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
576 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
577 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
580 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
/// Primary template of a compile-time map from a vector reduction intrinsic
/// op (`Source`) to a scalar op. Only declared here; the explicit
/// specializations expose the mapped op as a nested `Type` alias, which the
/// floating-point comparison reduction lowering reads via
/// `typename VectorToScalarMapper<LLVMRedIntrinOp>::Type`.
587 template <
typename Source>
588 struct VectorToScalarMapper;
590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
591 using Type = LLVM::MaximumOp;
594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
595 using Type = LLVM::MinimumOp;
598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
599 using Type = LLVM::MaxNumOp;
602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
603 using Type = LLVM::MinNumOp;
607 template <
class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
610 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
612 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
616 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617 loc, result, accumulator);
/// Empty tag types selecting a `getMaskNeutralValue` overload; callers pass a
/// value-initialized instance (`MaskNeutral{}`) together with the element
/// type's float semantics.
/// NOTE(review): the FMaximum overload appears to return
/// `APFloat::getSmallest(floatSemantics, true)` while the FMinimum overload
/// returns `APFloat::getLargest(floatSemantics, false)` -- confirm that
/// `getSmallest` is the intended neutral element for a maximum reduction.
624 class MaskNeutralFMaximum {};
625 class MaskNeutralFMinimum {};
629 getMaskNeutralValue(MaskNeutralFMaximum,
630 const llvm::fltSemantics &floatSemantics) {
631 return llvm::APFloat::getSmallest(floatSemantics,
true);
635 getMaskNeutralValue(MaskNeutralFMinimum,
636 const llvm::fltSemantics &floatSemantics) {
637 return llvm::APFloat::getLargest(floatSemantics,
false);
641 template <
typename MaskNeutral>
645 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
648 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
655 template <
class LLVMRedIntrinOp,
class MaskNeutral>
660 Value mask, LLVM::FastmathFlagsAttr fmf) {
661 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662 rewriter, loc, llvmType, vectorOperand.
getType());
663 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
664 loc, mask, vectorOperand, vectorMaskNeutral);
665 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
669 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
673 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675 llvmType, accumulator);
676 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
684 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
689 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690 llvmType, accumulator);
691 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
696 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
700 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701 llvmType, accumulator);
703 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
704 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
706 vectorOperand, mask, vectorLength);
709 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
710 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
715 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716 IntReductionNeutral>(
717 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
720 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
722 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
726 class VectorReductionOpConversion
730 bool reassociateFPRed)
732 reassociateFPReductions(reassociateFPRed) {}
735 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
737 auto kind = reductionOp.getKind();
738 Type eltType = reductionOp.getDest().getType();
740 Value operand = adaptor.getVector();
741 Value acc = adaptor.getAcc();
742 Location loc = reductionOp.getLoc();
748 case vector::CombiningKind::ADD:
750 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
752 rewriter, loc, llvmType, operand, acc);
754 case vector::CombiningKind::MUL:
756 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
758 rewriter, loc, llvmType, operand, acc);
761 result = createIntegerReductionComparisonOpLowering<
762 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
763 LLVM::ICmpPredicate::ule);
765 case vector::CombiningKind::MINSI:
766 result = createIntegerReductionComparisonOpLowering<
767 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
768 LLVM::ICmpPredicate::sle);
770 case vector::CombiningKind::MAXUI:
771 result = createIntegerReductionComparisonOpLowering<
772 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
773 LLVM::ICmpPredicate::uge);
775 case vector::CombiningKind::MAXSI:
776 result = createIntegerReductionComparisonOpLowering<
777 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
778 LLVM::ICmpPredicate::sge);
780 case vector::CombiningKind::AND:
782 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
784 rewriter, loc, llvmType, operand, acc);
786 case vector::CombiningKind::OR:
788 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
790 rewriter, loc, llvmType, operand, acc);
792 case vector::CombiningKind::XOR:
794 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
796 rewriter, loc, llvmType, operand, acc);
806 if (!isa<FloatType>(eltType))
809 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
811 reductionOp.getContext(),
814 reductionOp.getContext(),
815 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
816 : LLVM::FastmathFlags::none));
820 if (kind == vector::CombiningKind::ADD) {
821 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
822 ReductionNeutralZero>(
823 rewriter, loc, llvmType, operand, acc, fmf);
824 }
else if (kind == vector::CombiningKind::MUL) {
825 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
826 ReductionNeutralFPOne>(
827 rewriter, loc, llvmType, operand, acc, fmf);
828 }
else if (kind == vector::CombiningKind::MINIMUMF) {
830 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
831 rewriter, loc, llvmType, operand, acc, fmf);
832 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
834 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
835 rewriter, loc, llvmType, operand, acc, fmf);
836 }
else if (kind == vector::CombiningKind::MINNUMF) {
837 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
838 rewriter, loc, llvmType, operand, acc, fmf);
839 }
else if (kind == vector::CombiningKind::MAXNUMF) {
840 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
841 rewriter, loc, llvmType, operand, acc, fmf);
850 const bool reassociateFPReductions;
861 template <
class MaskedOp>
862 class VectorMaskOpConversionBase
868 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
871 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
874 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
878 virtual LogicalResult
879 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880 vector::MaskableOpInterface maskableOp,
884 class MaskedReductionOpConversion
885 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
888 using VectorMaskOpConversionBase<
889 vector::ReductionOp>::VectorMaskOpConversionBase;
891 LogicalResult matchAndRewriteMaskableOp(
892 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
894 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
895 auto kind = reductionOp.getKind();
896 Type eltType = reductionOp.getDest().getType();
898 Value operand = reductionOp.getVector();
899 Value acc = reductionOp.getAcc();
900 Location loc = reductionOp.getLoc();
902 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
904 reductionOp.getContext(),
909 case vector::CombiningKind::ADD:
910 result = lowerPredicatedReductionWithStartValue<
911 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
912 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
915 case vector::CombiningKind::MUL:
916 result = lowerPredicatedReductionWithStartValue<
917 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
918 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
922 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
923 ReductionNeutralUIntMax>(
924 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
926 case vector::CombiningKind::MINSI:
927 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
928 ReductionNeutralSIntMax>(
929 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
931 case vector::CombiningKind::MAXUI:
932 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
933 ReductionNeutralUIntMin>(
934 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
936 case vector::CombiningKind::MAXSI:
937 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
938 ReductionNeutralSIntMin>(
939 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
941 case vector::CombiningKind::AND:
942 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
943 ReductionNeutralAllOnes>(
944 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
946 case vector::CombiningKind::OR:
947 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
948 ReductionNeutralZero>(
949 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
951 case vector::CombiningKind::XOR:
952 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
953 ReductionNeutralZero>(
954 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
956 case vector::CombiningKind::MINNUMF:
957 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
958 ReductionNeutralFPMax>(
959 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
961 case vector::CombiningKind::MAXNUMF:
962 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
963 ReductionNeutralFPMin>(
964 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
966 case CombiningKind::MAXIMUMF:
967 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
968 MaskNeutralFMaximum>(
969 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
971 case CombiningKind::MINIMUMF:
972 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
973 MaskNeutralFMinimum>(
974 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
979 rewriter.replaceOp(maskOp, result);
984 class VectorShuffleOpConversion
990 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
992 auto loc = shuffleOp->getLoc();
993 auto v1Type = shuffleOp.getV1VectorType();
994 auto v2Type = shuffleOp.getV2VectorType();
995 auto vectorType = shuffleOp.getResultVectorType();
997 auto maskArrayAttr = shuffleOp.getMask();
1004 int64_t rank = vectorType.getRank();
1006 bool wellFormed0DCase =
1007 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008 bool wellFormedNDCase =
1009 v1Type.getRank() == rank && v2Type.getRank() == rank;
1010 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1015 if (rank <= 1 && v1Type == v2Type) {
1016 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1017 loc, adaptor.getV1(), adaptor.getV2(),
1018 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1019 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1024 int64_t v1Dim = v1Type.getDimSize(0);
1026 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027 eltType = arrayType.getElementType();
1029 eltType = cast<VectorType>(llvmType).getElementType();
1030 Value insert = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1033 int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1034 Value value = adaptor.getV1();
1035 if (extPos >= v1Dim) {
1037 value = adaptor.getV2();
1040 eltType, rank, extPos);
1041 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1042 llvmType, rank, insPos++);
1049 class VectorExtractElementOpConversion
1056 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1058 auto vectorType = extractEltOp.getSourceVectorType();
1059 auto llvmType = typeConverter->
convertType(vectorType.getElementType());
1065 if (vectorType.getRank() == 0) {
1066 Location loc = extractEltOp.getLoc();
1068 auto zero = rewriter.
create<LLVM::ConstantOp>(
1072 extractEltOp, llvmType, adaptor.getVector(), zero);
1077 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1082 class VectorExtractOpConversion
1088 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1090 auto loc = extractOp->getLoc();
1091 auto resultType = extractOp.getResult().getType();
1092 auto llvmResultType = typeConverter->
convertType(resultType);
1094 if (!llvmResultType)
1098 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1102 if (position.empty()) {
1103 rewriter.
replaceOp(extractOp, adaptor.getVector());
1108 if (isa<VectorType>(resultType)) {
1109 if (extractOp.hasDynamicPosition())
1112 Value extracted = rewriter.
create<LLVM::ExtractValueOp>(
1114 rewriter.
replaceOp(extractOp, extracted);
1119 Value extracted = adaptor.getVector();
1120 if (position.size() > 1) {
1121 if (extractOp.hasDynamicPosition())
1126 extracted = rewriter.
create<LLVM::ExtractValueOp>(loc, extracted,
1157 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1159 VectorType vType = fmaOp.getVectorType();
1160 if (vType.getRank() > 1)
1164 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1169 class VectorInsertElementOpConversion
1175 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1177 auto vectorType = insertEltOp.getDestVectorType();
1178 auto llvmType = typeConverter->
convertType(vectorType);
1184 if (vectorType.getRank() == 0) {
1185 Location loc = insertEltOp.getLoc();
1187 auto zero = rewriter.
create<LLVM::ConstantOp>(
1191 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1196 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1197 adaptor.getPosition());
1202 class VectorInsertOpConversion
1208 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1210 auto loc = insertOp->getLoc();
1211 auto sourceType = insertOp.getSourceType();
1212 auto destVectorType = insertOp.getDestVectorType();
1213 auto llvmResultType = typeConverter->
convertType(destVectorType);
1215 if (!llvmResultType)
1219 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1224 if (position.empty()) {
1225 rewriter.
replaceOp(insertOp, adaptor.getSource());
1230 if (isa<VectorType>(sourceType)) {
1231 if (insertOp.hasDynamicPosition())
1234 Value inserted = rewriter.
create<LLVM::InsertValueOp>(
1235 loc, adaptor.getDest(), adaptor.getSource(),
getAsIntegers(position));
1241 Value extracted = adaptor.getDest();
1242 auto oneDVectorType = destVectorType;
1243 if (position.size() > 1) {
1244 if (insertOp.hasDynamicPosition())
1248 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1253 Value inserted = rewriter.
create<LLVM::InsertElementOp>(
1254 loc, typeConverter->
convertType(oneDVectorType), extracted,
1255 adaptor.getSource(),
getAsLLVMValue(rewriter, loc, position.back()));
1258 if (position.size() > 1) {
1259 if (insertOp.hasDynamicPosition())
1262 inserted = rewriter.
create<LLVM::InsertValueOp>(
1263 loc, adaptor.getDest(), inserted,
1273 struct VectorScalableInsertOpLowering
1279 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1282 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1288 struct VectorScalableExtractOpLowering
1294 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1297 extOp, typeConverter->
convertType(extOp.getResultVectorType()),
1298 adaptor.getSource(), adaptor.getPos());
1331 setHasBoundedRewriteRecursion();
1334 LogicalResult matchAndRewrite(FMAOp op,
1336 auto vType = op.getVectorType();
1337 if (vType.getRank() < 2)
1341 auto elemType = vType.getElementType();
1344 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1345 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1346 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1347 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1348 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1349 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1350 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
1359 static std::optional<SmallVector<int64_t, 4>>
1360 computeContiguousStrides(MemRefType memRefType) {
1364 return std::nullopt;
1365 if (!strides.empty() && strides.back() != 1)
1366 return std::nullopt;
1368 if (memRefType.getLayout().isIdentity())
1375 auto sizes = memRefType.getShape();
1376 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1377 if (ShapedType::isDynamic(sizes[index + 1]) ||
1378 ShapedType::isDynamic(strides[index]) ||
1379 ShapedType::isDynamic(strides[index + 1]))
1380 return std::nullopt;
1381 if (strides[index] != strides[index + 1] * sizes[index + 1])
1382 return std::nullopt;
1387 class VectorTypeCastOpConversion
1393 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1395 auto loc = castOp->getLoc();
1396 MemRefType sourceMemRefType =
1397 cast<MemRefType>(castOp.getOperand().getType());
1398 MemRefType targetMemRefType = castOp.getType();
1401 if (!sourceMemRefType.hasStaticShape() ||
1402 !targetMemRefType.hasStaticShape())
1405 auto llvmSourceDescriptorTy =
1406 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1407 if (!llvmSourceDescriptorTy)
1411 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1413 if (!llvmTargetDescriptorTy)
1417 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1420 auto targetStrides = computeContiguousStrides(targetMemRefType);
1424 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1432 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1433 desc.setAllocatedPtr(rewriter, loc, allocated);
1436 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1437 desc.setAlignedPtr(rewriter, loc, ptr);
1440 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1441 desc.setOffset(rewriter, loc, zero);
1444 for (
const auto &indexedSize :
1446 int64_t index = indexedSize.index();
1449 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1450 desc.setSize(rewriter, loc, index, size);
1452 (*targetStrides)[index]);
1453 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1454 desc.setStride(rewriter, loc, index, stride);
1464 class VectorCreateMaskOpRewritePattern
1467 explicit VectorCreateMaskOpRewritePattern(
MLIRContext *context,
1468 bool enableIndexOpt)
1470 force32BitVectorIndices(enableIndexOpt) {}
1472 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1474 auto dstType = op.getType();
1475 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1477 IntegerType idxType =
1480 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1486 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1493 const bool force32BitVectorIndices;
1514 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1516 auto parent =
printOp->getParentOfType<ModuleOp>();
1522 if (
auto value = adaptor.getSource()) {
1528 if (failed(emitScalarPrint(rewriter, parent, loc,
printType, value)))
1532 auto punct =
printOp.getPunctuation();
1533 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1535 *stringLiteral, *getTypeConverter(),
1537 }
else if (punct != PrintPunctuation::NoPunctuation) {
1538 emitCall(rewriter,
printOp->getLoc(), [&] {
1540 case PrintPunctuation::Close:
1541 return LLVM::lookupOrCreatePrintCloseFn(parent);
1542 case PrintPunctuation::Open:
1543 return LLVM::lookupOrCreatePrintOpenFn(parent);
1544 case PrintPunctuation::Comma:
1545 return LLVM::lookupOrCreatePrintCommaFn(parent);
1546 case PrintPunctuation::NewLine:
1547 return LLVM::lookupOrCreatePrintNewlineFn(parent);
1549 llvm_unreachable(
"unexpected punctuation");
1559 enum class PrintConversion {
1570 Value value)
const {
1582 conversion = PrintConversion::Bitcast16;
1585 conversion = PrintConversion::Bitcast16;
1589 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1593 unsigned width = intTy.getWidth();
1594 if (intTy.isUnsigned()) {
1597 conversion = PrintConversion::ZeroExt64;
1603 assert(intTy.isSignless() || intTy.isSigned());
1608 conversion = PrintConversion::ZeroExt64;
1609 else if (width < 64)
1610 conversion = PrintConversion::SignExt64;
1620 switch (conversion) {
1621 case PrintConversion::ZeroExt64:
1622 value = rewriter.
create<arith::ExtUIOp>(
1625 case PrintConversion::SignExt64:
1626 value = rewriter.
create<arith::ExtSIOp>(
1629 case PrintConversion::Bitcast16:
1630 value = rewriter.
create<LLVM::BitcastOp>(
1636 emitCall(rewriter, loc, printer, value);
1654 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1656 VectorType resultType = cast<VectorType>(splatOp.getType());
1657 if (resultType.getRank() > 1)
1661 auto vectorType = typeConverter->
convertType(splatOp.getType());
1662 Value undef = rewriter.
create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1663 auto zero = rewriter.
create<LLVM::ConstantOp>(
1669 if (resultType.getRank() == 0) {
1671 splatOp, vectorType, undef, adaptor.getInput(), zero);
1676 auto v = rewriter.
create<LLVM::InsertElementOp>(
1677 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1679 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1696 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1698 VectorType resultType = splatOp.getType();
1699 if (resultType.getRank() <= 1)
1703 auto loc = splatOp.getLoc();
1704 auto vectorTypeInfo =
1706 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1707 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1708 if (!llvmNDVectorTy || !llvm1DVectorTy)
1712 Value desc = rewriter.
create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1716 Value vdesc = rewriter.
create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1717 auto zero = rewriter.
create<LLVM::ConstantOp>(
1720 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1721 adaptor.getInput(), zero);
1724 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1726 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1731 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1740 struct VectorInterleaveOpLowering
1745 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1747 VectorType resultType = interleaveOp.getResultVectorType();
1749 if (resultType.getRank() != 1)
1751 "InterleaveOp not rank 1");
1753 if (resultType.isScalable()) {
1755 interleaveOp, typeConverter->
convertType(resultType),
1756 adaptor.getLhs(), adaptor.getRhs());
1763 int64_t resultVectorSize = resultType.getNumElements();
1765 interleaveShuffleMask.reserve(resultVectorSize);
1766 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1767 interleaveShuffleMask.push_back(i);
1768 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1771 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1772 interleaveShuffleMask);
1779 struct VectorDeinterleaveOpLowering
1784 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1786 VectorType resultType = deinterleaveOp.getResultVectorType();
1787 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1788 auto loc = deinterleaveOp.getLoc();
1792 if (resultType.getRank() != 1)
1794 "DeinterleaveOp not rank 1");
1796 if (resultType.isScalable()) {
1797 auto llvmTypeConverter = this->getTypeConverter();
1798 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1799 auto packedOpResults =
1800 llvmTypeConverter->packOperationResults(deinterleaveResults);
1801 auto intrinsic = rewriter.
create<LLVM::vector_deinterleave2>(
1802 loc, packedOpResults, adaptor.getSource());
1804 auto evenResult = rewriter.
create<LLVM::ExtractValueOp>(
1805 loc, intrinsic->getResult(0), 0);
1806 auto oddResult = rewriter.
create<LLVM::ExtractValueOp>(
1807 loc, intrinsic->getResult(0), 1);
1816 int64_t resultVectorSize = resultType.getNumElements();
1820 evenShuffleMask.reserve(resultVectorSize);
1821 oddShuffleMask.reserve(resultVectorSize);
1823 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1825 evenShuffleMask.push_back(i);
1827 oddShuffleMask.push_back(i);
1830 auto poison = rewriter.
create<LLVM::PoisonOp>(loc, sourceType);
1831 auto evenShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1832 loc, adaptor.getSource(), poison, evenShuffleMask);
1833 auto oddShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1834 loc, adaptor.getSource(), poison, oddShuffleMask);
1842 struct VectorFromElementsLowering
1847 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1849 Location loc = fromElementsOp.getLoc();
1850 VectorType vectorType = fromElementsOp.getType();
1853 if (vectorType.getRank() > 1)
1855 "rank > 1 vectors are not supported");
1857 Value result = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1859 result = rewriter.
create<vector::InsertOp>(loc, val, result, idx);
1860 rewriter.
replaceOp(fromElementsOp, result);
1870 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1883 bool reassociateFPReductions,
bool force32BitVectorIndices) {
1885 patterns.
add<VectorFMAOpNDRewritePattern>(ctx);
1887 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1888 patterns.
add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1889 patterns.
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1890 VectorExtractElementOpConversion, VectorExtractOpConversion,
1891 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1892 VectorInsertOpConversion, VectorPrintOpConversion,
1893 VectorTypeCastOpConversion, VectorScaleOpConversion,
1894 VectorLoadStoreConversion<vector::LoadOp>,
1895 VectorLoadStoreConversion<vector::MaskedLoadOp>,
1896 VectorLoadStoreConversion<vector::StoreOp>,
1897 VectorLoadStoreConversion<vector::MaskedStoreOp>,
1898 VectorGatherOpConversion, VectorScatterOpConversion,
1899 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1900 VectorSplatOpLowering, VectorSplatNdOpLowering,
1901 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1902 MaskedReductionOpConversion, VectorInterleaveOpLowering,
1903 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1904 VectorStepOpLowering>(converter);
1911 patterns.
add<VectorMatmulOpConversion>(converter);
1912 patterns.
add<VectorFlatTransposeOpConversion>(converter);
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...