30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
39 assert((tp.getRank() > 1) &&
"unlowerable vector type");
41 tp.getScalableDims().take_back());
49 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
52 auto constant = rewriter.
create<LLVM::ConstantOp>(
55 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
64 Value val,
Type llvmType, int64_t rank, int64_t pos) {
67 auto constant = rewriter.
create<LLVM::ConstantOp>(
70 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
73 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
78 MemRefType memrefType,
unsigned &align) {
79 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
85 llvm::LLVMContext llvmContext;
104 MemRefType memRefType,
Value llvmMemref,
Value base,
105 Value index, VectorType vectorType) {
107 "unsupported memref type");
108 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
112 vectorType.getScalableDims()[0]);
113 return rewriter.
create<LLVM::GEPOp>(
114 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
122 if (
auto attr = foldResult.dyn_cast<
Attribute>()) {
123 auto intAttr = cast<IntegerAttr>(attr);
124 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
127 return foldResult.get<
Value>();
133 using VectorScaleOpConversion =
137 class VectorBitCastOpConversion
143 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
146 VectorType resultTy = bitCastOp.getResultVectorType();
147 if (resultTy.getRank() > 1)
149 Type newResultTy = typeConverter->convertType(resultTy);
151 adaptor.getOperands()[0]);
158 class VectorMatmulOpConversion
164 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
167 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
168 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
169 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
176 class VectorFlatTransposeOpConversion
182 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
185 transOp, typeConverter->convertType(transOp.getRes().getType()),
186 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
195 vector::LoadOpAdaptor adaptor,
196 VectorType vectorTy,
Value ptr,
unsigned align,
200 loadOp.getNontemporal());
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204 vector::MaskedLoadOpAdaptor adaptor,
205 VectorType vectorTy,
Value ptr,
unsigned align,
208 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212 vector::StoreOpAdaptor adaptor,
213 VectorType vectorTy,
Value ptr,
unsigned align,
217 storeOp.getNontemporal());
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221 vector::MaskedStoreOpAdaptor adaptor,
222 VectorType vectorTy,
Value ptr,
unsigned align,
225 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
230 template <
class LoadOrStoreOp>
236 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
237 typename LoadOrStoreOp::Adaptor adaptor,
240 VectorType vectorTy = loadOrStoreOp.getVectorType();
241 if (vectorTy.getRank() > 1)
244 auto loc = loadOrStoreOp->getLoc();
245 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
253 auto vtype = cast<VectorType>(
254 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
255 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
256 adaptor.getIndices(), rewriter);
257 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
264 class VectorGatherOpConversion
270 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
272 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
273 assert(memRefType &&
"The base should be bufferized");
278 auto loc = gather->getLoc();
285 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
286 adaptor.getIndices(), rewriter);
287 Value base = adaptor.getBase();
289 auto llvmNDVectorTy = adaptor.getIndexVec().
getType();
291 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
292 auto vType = gather.getVectorType();
295 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
296 base, ptr, adaptor.getIndexVec(), vType);
299 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
305 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
306 &typeConverter](
Type llvm1DVectorTy,
310 rewriter, loc, typeConverter, memRefType, base, ptr,
311 vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
313 return rewriter.create<LLVM::masked_gather>(
314 loc, llvm1DVectorTy, ptrs, vectorOperands[1],
315 vectorOperands[2], rewriter.getI32IntegerAttr(align));
318 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
320 gather, vectorOperands, *getTypeConverter(), callback, rewriter);
325 class VectorScatterOpConversion
331 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
333 auto loc = scatter->getLoc();
334 MemRefType memRefType = scatter.getMemRefType();
345 VectorType vType = scatter.getVectorType();
346 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
347 adaptor.getIndices(), rewriter);
349 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
350 adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
354 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
361 class VectorExpandLoadOpConversion
367 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
369 auto loc = expand->getLoc();
370 MemRefType memRefType = expand.getMemRefType();
373 auto vtype = typeConverter->
convertType(expand.getVectorType());
374 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
375 adaptor.getIndices(), rewriter);
378 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
384 class VectorCompressStoreOpConversion
390 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
392 auto loc = compress->getLoc();
393 MemRefType memRefType = compress.getMemRefType();
396 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397 adaptor.getIndices(), rewriter);
400 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
406 class ReductionNeutralZero {};
407 class ReductionNeutralIntOne {};
408 class ReductionNeutralFPOne {};
409 class ReductionNeutralAllOnes {};
410 class ReductionNeutralSIntMin {};
411 class ReductionNeutralUIntMin {};
412 class ReductionNeutralSIntMax {};
413 class ReductionNeutralUIntMax {};
414 class ReductionNeutralFPMin {};
415 class ReductionNeutralFPMax {};
418 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
421 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
429 return rewriter.
create<LLVM::ConstantOp>(
434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
437 return rewriter.
create<LLVM::ConstantOp>(
442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
445 return rewriter.
create<LLVM::ConstantOp>(
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
455 return rewriter.
create<LLVM::ConstantOp>(
462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
465 return rewriter.
create<LLVM::ConstantOp>(
472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
475 return rewriter.
create<LLVM::ConstantOp>(
482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
485 return rewriter.
create<LLVM::ConstantOp>(
492 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
495 auto floatType = cast<FloatType>(llvmType);
496 return rewriter.
create<LLVM::ConstantOp>(
499 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
504 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
507 auto floatType = cast<FloatType>(llvmType);
508 return rewriter.
create<LLVM::ConstantOp>(
511 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
517 template <
class ReductionNeutral>
524 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
533 VectorType vType = cast<VectorType>(llvmType);
534 auto vShape = vType.getShape();
535 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
537 Value baseVecLength = rewriter.
create<LLVM::ConstantOp>(
541 if (!vType.getScalableDims()[0])
542 return baseVecLength;
545 Value vScale = rewriter.
create<vector::VectorScaleOp>(loc);
548 Value scalableVecLength =
549 rewriter.
create<arith::MulIOp>(loc, baseVecLength, vScale);
550 return scalableVecLength;
557 template <
class LLVMRedIntrinOp,
class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
562 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
565 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
573 template <
class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
576 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
577 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
580 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
587 template <
typename Source>
588 struct VectorToScalarMapper;
590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
591 using Type = LLVM::MaximumOp;
594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
595 using Type = LLVM::MinimumOp;
598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
599 using Type = LLVM::MaxNumOp;
602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
603 using Type = LLVM::MinNumOp;
607 template <
class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
610 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
612 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
616 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617 loc, result, accumulator);
624 class MaskNeutralFMaximum {};
625 class MaskNeutralFMinimum {};
629 getMaskNeutralValue(MaskNeutralFMaximum,
630 const llvm::fltSemantics &floatSemantics) {
// Value written into masked-off lanes before a vector_reduce_fmaximum: it must
// never be selected as the maximum, so use the most negative finite value
// (mirroring the MaskNeutralFMinimum overload, which uses the largest positive
// finite value). APFloat::getSmallest(sem, /*Negative=*/true) is -denorm_min,
// which exceeds every negative normal number and would corrupt reductions whose
// active lanes are all negative. getInf is avoided because NaN-only formats
// (e.g. f8E4M3FN) yield NaN there, which fmaximum would propagate.
631 return llvm::APFloat::getLargest(floatSemantics,
/*Negative=*/true);
635 getMaskNeutralValue(MaskNeutralFMinimum,
636 const llvm::fltSemantics &floatSemantics) {
637 return llvm::APFloat::getLargest(floatSemantics,
false);
641 template <
typename MaskNeutral>
645 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
648 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
655 template <
class LLVMRedIntrinOp,
class MaskNeutral>
660 Value mask, LLVM::FastmathFlagsAttr fmf) {
661 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662 rewriter, loc, llvmType, vectorOperand.
getType());
663 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
664 loc, mask, vectorOperand, vectorMaskNeutral);
665 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
669 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
673 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675 llvmType, accumulator);
676 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
684 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
689 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690 llvmType, accumulator);
691 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
696 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
700 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701 llvmType, accumulator);
703 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
704 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
706 vectorOperand, mask, vectorLength);
709 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
710 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
715 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716 IntReductionNeutral>(
717 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
720 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
722 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
726 class VectorReductionOpConversion
730 bool reassociateFPRed)
732 reassociateFPReductions(reassociateFPRed) {}
735 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
737 auto kind = reductionOp.getKind();
738 Type eltType = reductionOp.getDest().getType();
740 Value operand = adaptor.getVector();
741 Value acc = adaptor.getAcc();
742 Location loc = reductionOp.getLoc();
748 case vector::CombiningKind::ADD:
750 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
752 rewriter, loc, llvmType, operand, acc);
754 case vector::CombiningKind::MUL:
756 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
758 rewriter, loc, llvmType, operand, acc);
761 result = createIntegerReductionComparisonOpLowering<
762 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
763 LLVM::ICmpPredicate::ule);
765 case vector::CombiningKind::MINSI:
766 result = createIntegerReductionComparisonOpLowering<
767 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
768 LLVM::ICmpPredicate::sle);
770 case vector::CombiningKind::MAXUI:
771 result = createIntegerReductionComparisonOpLowering<
772 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
773 LLVM::ICmpPredicate::uge);
775 case vector::CombiningKind::MAXSI:
776 result = createIntegerReductionComparisonOpLowering<
777 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
778 LLVM::ICmpPredicate::sge);
780 case vector::CombiningKind::AND:
782 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
784 rewriter, loc, llvmType, operand, acc);
786 case vector::CombiningKind::OR:
788 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
790 rewriter, loc, llvmType, operand, acc);
792 case vector::CombiningKind::XOR:
794 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
796 rewriter, loc, llvmType, operand, acc);
806 if (!isa<FloatType>(eltType))
809 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
811 reductionOp.getContext(),
814 reductionOp.getContext(),
815 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
816 : LLVM::FastmathFlags::none));
820 if (kind == vector::CombiningKind::ADD) {
821 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
822 ReductionNeutralZero>(
823 rewriter, loc, llvmType, operand, acc, fmf);
824 }
else if (kind == vector::CombiningKind::MUL) {
825 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
826 ReductionNeutralFPOne>(
827 rewriter, loc, llvmType, operand, acc, fmf);
828 }
else if (kind == vector::CombiningKind::MINIMUMF) {
830 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
831 rewriter, loc, llvmType, operand, acc, fmf);
832 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
834 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
835 rewriter, loc, llvmType, operand, acc, fmf);
836 }
else if (kind == vector::CombiningKind::MINNUMF) {
837 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
838 rewriter, loc, llvmType, operand, acc, fmf);
839 }
else if (kind == vector::CombiningKind::MAXNUMF) {
840 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
841 rewriter, loc, llvmType, operand, acc, fmf);
850 const bool reassociateFPReductions;
861 template <
class MaskedOp>
862 class VectorMaskOpConversionBase
868 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
871 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
874 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
878 virtual LogicalResult
879 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880 vector::MaskableOpInterface maskableOp,
884 class MaskedReductionOpConversion
885 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
888 using VectorMaskOpConversionBase<
889 vector::ReductionOp>::VectorMaskOpConversionBase;
891 LogicalResult matchAndRewriteMaskableOp(
892 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
894 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
895 auto kind = reductionOp.getKind();
896 Type eltType = reductionOp.getDest().getType();
898 Value operand = reductionOp.getVector();
899 Value acc = reductionOp.getAcc();
900 Location loc = reductionOp.getLoc();
902 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
904 reductionOp.getContext(),
909 case vector::CombiningKind::ADD:
910 result = lowerPredicatedReductionWithStartValue<
911 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
912 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
915 case vector::CombiningKind::MUL:
916 result = lowerPredicatedReductionWithStartValue<
917 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
918 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
922 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
923 ReductionNeutralUIntMax>(
924 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
926 case vector::CombiningKind::MINSI:
927 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
928 ReductionNeutralSIntMax>(
929 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
931 case vector::CombiningKind::MAXUI:
932 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
933 ReductionNeutralUIntMin>(
934 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
936 case vector::CombiningKind::MAXSI:
937 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
938 ReductionNeutralSIntMin>(
939 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
941 case vector::CombiningKind::AND:
942 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
943 ReductionNeutralAllOnes>(
944 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
946 case vector::CombiningKind::OR:
947 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
948 ReductionNeutralZero>(
949 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
951 case vector::CombiningKind::XOR:
952 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
953 ReductionNeutralZero>(
954 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
956 case vector::CombiningKind::MINNUMF:
957 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
958 ReductionNeutralFPMax>(
959 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
961 case vector::CombiningKind::MAXNUMF:
962 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
963 ReductionNeutralFPMin>(
964 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
966 case CombiningKind::MAXIMUMF:
967 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
968 MaskNeutralFMaximum>(
969 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
971 case CombiningKind::MINIMUMF:
972 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
973 MaskNeutralFMinimum>(
974 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
979 rewriter.replaceOp(maskOp, result);
984 class VectorShuffleOpConversion
990 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
992 auto loc = shuffleOp->getLoc();
993 auto v1Type = shuffleOp.getV1VectorType();
994 auto v2Type = shuffleOp.getV2VectorType();
995 auto vectorType = shuffleOp.getResultVectorType();
1004 int64_t rank = vectorType.getRank();
1006 bool wellFormed0DCase =
1007 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008 bool wellFormedNDCase =
1009 v1Type.getRank() == rank && v2Type.getRank() == rank;
1010 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1015 if (rank <= 1 && v1Type == v2Type) {
1016 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1017 loc, adaptor.getV1(), adaptor.getV2(),
1018 llvm::to_vector_of<int32_t>(mask));
1019 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1024 int64_t v1Dim = v1Type.getDimSize(0);
1026 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027 eltType = arrayType.getElementType();
1029 eltType = cast<VectorType>(llvmType).getElementType();
1030 Value insert = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1032 for (int64_t extPos : mask) {
1033 Value value = adaptor.getV1();
1034 if (extPos >= v1Dim) {
1036 value = adaptor.getV2();
1039 eltType, rank, extPos);
1040 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1041 llvmType, rank, insPos++);
1048 class VectorExtractElementOpConversion
1055 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1057 auto vectorType = extractEltOp.getSourceVectorType();
1058 auto llvmType = typeConverter->
convertType(vectorType.getElementType());
1064 if (vectorType.getRank() == 0) {
1065 Location loc = extractEltOp.getLoc();
1067 auto zero = rewriter.
create<LLVM::ConstantOp>(
1071 extractEltOp, llvmType, adaptor.getVector(), zero);
1076 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1081 class VectorExtractOpConversion
1087 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1089 auto loc = extractOp->getLoc();
1090 auto resultType = extractOp.getResult().getType();
1091 auto llvmResultType = typeConverter->
convertType(resultType);
1093 if (!llvmResultType)
1097 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1101 if (position.empty()) {
1102 rewriter.
replaceOp(extractOp, adaptor.getVector());
1108 if (isa<VectorType>(resultType) &&
1110 static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
1111 if (extractOp.hasDynamicPosition())
1114 Value extracted = rewriter.
create<LLVM::ExtractValueOp>(
1116 rewriter.
replaceOp(extractOp, extracted);
1121 Value extracted = adaptor.getVector();
1122 if (position.size() > 1) {
1123 if (extractOp.hasDynamicPosition())
1128 extracted = rewriter.
create<LLVM::ExtractValueOp>(loc, extracted,
1159 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1161 VectorType vType = fmaOp.getVectorType();
1162 if (vType.getRank() > 1)
1166 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1171 class VectorInsertElementOpConversion
1177 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1179 auto vectorType = insertEltOp.getDestVectorType();
1180 auto llvmType = typeConverter->
convertType(vectorType);
1186 if (vectorType.getRank() == 0) {
1187 Location loc = insertEltOp.getLoc();
1189 auto zero = rewriter.
create<LLVM::ConstantOp>(
1193 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1198 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1199 adaptor.getPosition());
1204 class VectorInsertOpConversion
1210 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1212 auto loc = insertOp->getLoc();
1213 auto sourceType = insertOp.getSourceType();
1214 auto destVectorType = insertOp.getDestVectorType();
1215 auto llvmResultType = typeConverter->
convertType(destVectorType);
1217 if (!llvmResultType)
1221 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1226 if (position.empty()) {
1227 rewriter.
replaceOp(insertOp, adaptor.getSource());
1232 if (isa<VectorType>(sourceType)) {
1233 if (insertOp.hasDynamicPosition())
1236 Value inserted = rewriter.
create<LLVM::InsertValueOp>(
1237 loc, adaptor.getDest(), adaptor.getSource(),
getAsIntegers(position));
1243 Value extracted = adaptor.getDest();
1244 auto oneDVectorType = destVectorType;
1245 if (position.size() > 1) {
1246 if (insertOp.hasDynamicPosition())
1250 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1255 Value inserted = rewriter.
create<LLVM::InsertElementOp>(
1256 loc, typeConverter->
convertType(oneDVectorType), extracted,
1257 adaptor.getSource(),
getAsLLVMValue(rewriter, loc, position.back()));
1260 if (position.size() > 1) {
1261 if (insertOp.hasDynamicPosition())
1264 inserted = rewriter.
create<LLVM::InsertValueOp>(
1265 loc, adaptor.getDest(), inserted,
1275 struct VectorScalableInsertOpLowering
1281 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1284 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1290 struct VectorScalableExtractOpLowering
1296 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1299 extOp, typeConverter->
convertType(extOp.getResultVectorType()),
1300 adaptor.getSource(), adaptor.getPos());
1333 setHasBoundedRewriteRecursion();
1336 LogicalResult matchAndRewrite(FMAOp op,
1338 auto vType = op.getVectorType();
1339 if (vType.getRank() < 2)
1342 auto loc = op.getLoc();
1343 auto elemType = vType.getElementType();
1346 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1347 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1348 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1349 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1350 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1351 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1352 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
1361 static std::optional<SmallVector<int64_t, 4>>
1362 computeContiguousStrides(MemRefType memRefType) {
1366 return std::nullopt;
1367 if (!strides.empty() && strides.back() != 1)
1368 return std::nullopt;
1370 if (memRefType.getLayout().isIdentity())
1377 auto sizes = memRefType.getShape();
1378 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1379 if (ShapedType::isDynamic(sizes[index + 1]) ||
1380 ShapedType::isDynamic(strides[index]) ||
1381 ShapedType::isDynamic(strides[index + 1]))
1382 return std::nullopt;
1383 if (strides[index] != strides[index + 1] * sizes[index + 1])
1384 return std::nullopt;
1389 class VectorTypeCastOpConversion
1395 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1397 auto loc = castOp->getLoc();
1398 MemRefType sourceMemRefType =
1399 cast<MemRefType>(castOp.getOperand().getType());
1400 MemRefType targetMemRefType = castOp.getType();
1403 if (!sourceMemRefType.hasStaticShape() ||
1404 !targetMemRefType.hasStaticShape())
1407 auto llvmSourceDescriptorTy =
1408 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1409 if (!llvmSourceDescriptorTy)
1413 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1415 if (!llvmTargetDescriptorTy)
1419 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1422 auto targetStrides = computeContiguousStrides(targetMemRefType);
1426 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1434 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1435 desc.setAllocatedPtr(rewriter, loc, allocated);
1438 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1439 desc.setAlignedPtr(rewriter, loc, ptr);
1442 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1443 desc.setOffset(rewriter, loc, zero);
1446 for (
const auto &indexedSize :
1448 int64_t index = indexedSize.index();
1451 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1452 desc.setSize(rewriter, loc, index, size);
1454 (*targetStrides)[index]);
1455 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1456 desc.setStride(rewriter, loc, index, stride);
1466 class VectorCreateMaskOpRewritePattern
1469 explicit VectorCreateMaskOpRewritePattern(
MLIRContext *context,
1470 bool enableIndexOpt)
1472 force32BitVectorIndices(enableIndexOpt) {}
1474 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1476 auto dstType = op.getType();
1477 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1479 IntegerType idxType =
1481 auto loc = op->getLoc();
1482 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1488 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1495 const bool force32BitVectorIndices;
1516 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1518 auto parent =
printOp->getParentOfType<ModuleOp>();
1524 if (
auto value = adaptor.getSource()) {
1530 if (failed(emitScalarPrint(rewriter, parent, loc,
printType, value)))
1534 auto punct =
printOp.getPunctuation();
1535 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1537 *stringLiteral, *getTypeConverter(),
1539 }
else if (punct != PrintPunctuation::NoPunctuation) {
1540 emitCall(rewriter,
printOp->getLoc(), [&] {
1542 case PrintPunctuation::Close:
1543 return LLVM::lookupOrCreatePrintCloseFn(parent);
1544 case PrintPunctuation::Open:
1545 return LLVM::lookupOrCreatePrintOpenFn(parent);
1546 case PrintPunctuation::Comma:
1547 return LLVM::lookupOrCreatePrintCommaFn(parent);
1548 case PrintPunctuation::NewLine:
1549 return LLVM::lookupOrCreatePrintNewlineFn(parent);
1551 llvm_unreachable(
"unexpected punctuation");
1561 enum class PrintConversion {
1572 Value value)
const {
1584 conversion = PrintConversion::Bitcast16;
1587 conversion = PrintConversion::Bitcast16;
1591 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1595 unsigned width = intTy.getWidth();
1596 if (intTy.isUnsigned()) {
1599 conversion = PrintConversion::ZeroExt64;
1605 assert(intTy.isSignless() || intTy.isSigned());
1610 conversion = PrintConversion::ZeroExt64;
1611 else if (width < 64)
1612 conversion = PrintConversion::SignExt64;
1622 switch (conversion) {
1623 case PrintConversion::ZeroExt64:
1624 value = rewriter.
create<arith::ExtUIOp>(
1627 case PrintConversion::SignExt64:
1628 value = rewriter.
create<arith::ExtSIOp>(
1631 case PrintConversion::Bitcast16:
1632 value = rewriter.
create<LLVM::BitcastOp>(
1638 emitCall(rewriter, loc, printer, value);
1656 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1658 VectorType resultType = cast<VectorType>(splatOp.getType());
1659 if (resultType.getRank() > 1)
1663 auto vectorType = typeConverter->
convertType(splatOp.getType());
1664 Value undef = rewriter.
create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1665 auto zero = rewriter.
create<LLVM::ConstantOp>(
1671 if (resultType.getRank() == 0) {
1673 splatOp, vectorType, undef, adaptor.getInput(), zero);
1678 auto v = rewriter.
create<LLVM::InsertElementOp>(
1679 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1681 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1698 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1700 VectorType resultType = splatOp.getType();
1701 if (resultType.getRank() <= 1)
1705 auto loc = splatOp.getLoc();
1706 auto vectorTypeInfo =
1708 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1709 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1710 if (!llvmNDVectorTy || !llvm1DVectorTy)
1714 Value desc = rewriter.
create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1718 Value vdesc = rewriter.
create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1719 auto zero = rewriter.
create<LLVM::ConstantOp>(
1722 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1723 adaptor.getInput(), zero);
1726 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1728 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1733 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1742 struct VectorInterleaveOpLowering
1747 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1749 VectorType resultType = interleaveOp.getResultVectorType();
1751 if (resultType.getRank() != 1)
1753 "InterleaveOp not rank 1");
1755 if (resultType.isScalable()) {
1757 interleaveOp, typeConverter->
convertType(resultType),
1758 adaptor.getLhs(), adaptor.getRhs());
1765 int64_t resultVectorSize = resultType.getNumElements();
1767 interleaveShuffleMask.reserve(resultVectorSize);
1768 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1769 interleaveShuffleMask.push_back(i);
1770 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1773 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1774 interleaveShuffleMask);
1781 struct VectorDeinterleaveOpLowering
1786 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1788 VectorType resultType = deinterleaveOp.getResultVectorType();
1789 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1790 auto loc = deinterleaveOp.getLoc();
1794 if (resultType.getRank() != 1)
1796 "DeinterleaveOp not rank 1");
1798 if (resultType.isScalable()) {
1799 auto llvmTypeConverter = this->getTypeConverter();
1800 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1801 auto packedOpResults =
1802 llvmTypeConverter->packOperationResults(deinterleaveResults);
1803 auto intrinsic = rewriter.
create<LLVM::vector_deinterleave2>(
1804 loc, packedOpResults, adaptor.getSource());
1806 auto evenResult = rewriter.
create<LLVM::ExtractValueOp>(
1807 loc, intrinsic->getResult(0), 0);
1808 auto oddResult = rewriter.
create<LLVM::ExtractValueOp>(
1809 loc, intrinsic->getResult(0), 1);
1818 int64_t resultVectorSize = resultType.getNumElements();
1822 evenShuffleMask.reserve(resultVectorSize);
1823 oddShuffleMask.reserve(resultVectorSize);
1825 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1827 evenShuffleMask.push_back(i);
1829 oddShuffleMask.push_back(i);
1832 auto poison = rewriter.
create<LLVM::PoisonOp>(loc, sourceType);
1833 auto evenShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1834 loc, adaptor.getSource(), poison, evenShuffleMask);
1835 auto oddShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1836 loc, adaptor.getSource(), poison, oddShuffleMask);
1844 struct VectorFromElementsLowering
1849 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1851 Location loc = fromElementsOp.getLoc();
1852 VectorType vectorType = fromElementsOp.getType();
1855 if (vectorType.getRank() > 1)
1857 "rank > 1 vectors are not supported");
1859 Value result = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1861 result = rewriter.
create<vector::InsertOp>(loc, val, result, idx);
1862 rewriter.
replaceOp(fromElementsOp, result);
1868 struct VectorScalableStepOpLowering
1873 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1875 auto resultType = cast<VectorType>(stepOp.getType());
1876 if (!resultType.isScalable()) {
1890 bool reassociateFPReductions,
bool force32BitVectorIndices) {
1892 patterns.
add<VectorFMAOpNDRewritePattern>(ctx);
1895 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1896 patterns.
add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1897 patterns.
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1898 VectorExtractElementOpConversion, VectorExtractOpConversion,
1899 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1900 VectorInsertOpConversion, VectorPrintOpConversion,
1901 VectorTypeCastOpConversion, VectorScaleOpConversion,
1902 VectorLoadStoreConversion<vector::LoadOp>,
1903 VectorLoadStoreConversion<vector::MaskedLoadOp>,
1904 VectorLoadStoreConversion<vector::StoreOp>,
1905 VectorLoadStoreConversion<vector::MaskedStoreOp>,
1906 VectorGatherOpConversion, VectorScatterOpConversion,
1907 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1908 VectorSplatOpLowering, VectorSplatNdOpLowering,
1909 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1910 MaskedReductionOpConversion, VectorInterleaveOpLowering,
1911 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1912 VectorScalableStepOpLowering>(converter);
1919 patterns.
add<VectorMatmulOpConversion>(converter);
1920 patterns.
add<VectorFlatTransposeOpConversion>(converter);
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...