30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
39 assert((tp.getRank() > 1) &&
"unlowerable vector type");
41 tp.getScalableDims().take_back());
49 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
52 auto constant = rewriter.
create<LLVM::ConstantOp>(
55 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
64 Value val,
Type llvmType, int64_t rank, int64_t pos) {
67 auto constant = rewriter.
create<LLVM::ConstantOp>(
70 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
73 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
78 MemRefType memrefType,
unsigned &align) {
79 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
85 llvm::LLVMContext llvmContext;
104 MemRefType memRefType,
Value llvmMemref,
Value base,
105 Value index, VectorType vectorType) {
107 "unsupported memref type");
108 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
112 vectorType.getScalableDims()[0]);
113 return rewriter.
create<LLVM::GEPOp>(
114 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
122 if (
auto attr = foldResult.dyn_cast<
Attribute>()) {
123 auto intAttr = cast<IntegerAttr>(attr);
124 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
127 return foldResult.get<
Value>();
133 using VectorScaleOpConversion =
137 class VectorBitCastOpConversion
143 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
146 VectorType resultTy = bitCastOp.getResultVectorType();
147 if (resultTy.getRank() > 1)
149 Type newResultTy = typeConverter->convertType(resultTy);
151 adaptor.getOperands()[0]);
158 class VectorMatmulOpConversion
164 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
167 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
168 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
169 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
176 class VectorFlatTransposeOpConversion
182 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
185 transOp, typeConverter->convertType(transOp.getRes().getType()),
186 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
195 vector::LoadOpAdaptor adaptor,
196 VectorType vectorTy,
Value ptr,
unsigned align,
200 loadOp.getNontemporal());
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204 vector::MaskedLoadOpAdaptor adaptor,
205 VectorType vectorTy,
Value ptr,
unsigned align,
208 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212 vector::StoreOpAdaptor adaptor,
213 VectorType vectorTy,
Value ptr,
unsigned align,
217 storeOp.getNontemporal());
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221 vector::MaskedStoreOpAdaptor adaptor,
222 VectorType vectorTy,
Value ptr,
unsigned align,
225 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
230 template <
class LoadOrStoreOp>
236 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
237 typename LoadOrStoreOp::Adaptor adaptor,
240 VectorType vectorTy = loadOrStoreOp.getVectorType();
241 if (vectorTy.getRank() > 1)
244 auto loc = loadOrStoreOp->getLoc();
245 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
253 auto vtype = cast<VectorType>(
254 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
255 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
256 adaptor.getIndices(), rewriter);
257 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
264 class VectorGatherOpConversion
270 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
272 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
273 assert(memRefType &&
"The base should be bufferized");
278 auto loc = gather->getLoc();
285 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
286 adaptor.getIndices(), rewriter);
287 Value base = adaptor.getBase();
289 auto llvmNDVectorTy = adaptor.getIndexVec().
getType();
291 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
292 auto vType = gather.getVectorType();
295 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
296 base, ptr, adaptor.getIndexVec(), vType);
299 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
305 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
306 &typeConverter](
Type llvm1DVectorTy,
310 rewriter, loc, typeConverter, memRefType, base, ptr,
311 vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
313 return rewriter.create<LLVM::masked_gather>(
314 loc, llvm1DVectorTy, ptrs, vectorOperands[1],
315 vectorOperands[2], rewriter.getI32IntegerAttr(align));
318 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
320 gather, vectorOperands, *getTypeConverter(), callback, rewriter);
325 class VectorScatterOpConversion
331 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
333 auto loc = scatter->getLoc();
334 MemRefType memRefType = scatter.getMemRefType();
345 VectorType vType = scatter.getVectorType();
346 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
347 adaptor.getIndices(), rewriter);
349 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
350 adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
354 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
361 class VectorExpandLoadOpConversion
367 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
369 auto loc = expand->getLoc();
370 MemRefType memRefType = expand.getMemRefType();
373 auto vtype = typeConverter->
convertType(expand.getVectorType());
374 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
375 adaptor.getIndices(), rewriter);
378 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
384 class VectorCompressStoreOpConversion
390 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
392 auto loc = compress->getLoc();
393 MemRefType memRefType = compress.getMemRefType();
396 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397 adaptor.getIndices(), rewriter);
400 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
406 class ReductionNeutralZero {};
407 class ReductionNeutralIntOne {};
408 class ReductionNeutralFPOne {};
409 class ReductionNeutralAllOnes {};
410 class ReductionNeutralSIntMin {};
411 class ReductionNeutralUIntMin {};
412 class ReductionNeutralSIntMax {};
413 class ReductionNeutralUIntMax {};
414 class ReductionNeutralFPMin {};
415 class ReductionNeutralFPMax {};
418 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
421 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
429 return rewriter.
create<LLVM::ConstantOp>(
434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
437 return rewriter.
create<LLVM::ConstantOp>(
442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
445 return rewriter.
create<LLVM::ConstantOp>(
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
455 return rewriter.
create<LLVM::ConstantOp>(
462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
465 return rewriter.
create<LLVM::ConstantOp>(
472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
475 return rewriter.
create<LLVM::ConstantOp>(
482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
485 return rewriter.
create<LLVM::ConstantOp>(
492 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
495 auto floatType = cast<FloatType>(llvmType);
496 return rewriter.
create<LLVM::ConstantOp>(
499 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
504 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
507 auto floatType = cast<FloatType>(llvmType);
508 return rewriter.
create<LLVM::ConstantOp>(
511 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
517 template <
class ReductionNeutral>
524 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
533 VectorType vType = cast<VectorType>(llvmType);
534 auto vShape = vType.getShape();
535 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
537 Value baseVecLength = rewriter.
create<LLVM::ConstantOp>(
541 if (!vType.getScalableDims()[0])
542 return baseVecLength;
545 Value vScale = rewriter.
create<vector::VectorScaleOp>(loc);
548 Value scalableVecLength =
549 rewriter.
create<arith::MulIOp>(loc, baseVecLength, vScale);
550 return scalableVecLength;
557 template <
class LLVMRedIntrinOp,
class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
562 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
565 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
573 template <
class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
576 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
577 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
580 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
587 template <
typename Source>
588 struct VectorToScalarMapper;
590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
591 using Type = LLVM::MaximumOp;
594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
595 using Type = LLVM::MinimumOp;
598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
599 using Type = LLVM::MaxNumOp;
602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
603 using Type = LLVM::MinNumOp;
607 template <
class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
610 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
612 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
616 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617 loc, result, accumulator);
624 class MaskNeutralFMaximum {};
625 class MaskNeutralFMinimum {};
629 getMaskNeutralValue(MaskNeutralFMaximum,
630 const llvm::fltSemantics &floatSemantics) {
631 return llvm::APFloat::getSmallest(floatSemantics,
true);
635 getMaskNeutralValue(MaskNeutralFMinimum,
636 const llvm::fltSemantics &floatSemantics) {
637 return llvm::APFloat::getLargest(floatSemantics,
false);
641 template <
typename MaskNeutral>
645 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
648 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
655 template <
class LLVMRedIntrinOp,
class MaskNeutral>
660 Value mask, LLVM::FastmathFlagsAttr fmf) {
661 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662 rewriter, loc, llvmType, vectorOperand.
getType());
663 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
664 loc, mask, vectorOperand, vectorMaskNeutral);
665 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
669 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
673 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675 llvmType, accumulator);
676 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
684 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
689 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690 llvmType, accumulator);
691 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
696 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
700 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701 llvmType, accumulator);
703 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
704 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
706 vectorOperand, mask, vectorLength);
709 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
710 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
715 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716 IntReductionNeutral>(
717 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
720 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
722 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
726 class VectorReductionOpConversion
730 bool reassociateFPRed)
732 reassociateFPReductions(reassociateFPRed) {}
735 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
737 auto kind = reductionOp.getKind();
738 Type eltType = reductionOp.getDest().getType();
740 Value operand = adaptor.getVector();
741 Value acc = adaptor.getAcc();
742 Location loc = reductionOp.getLoc();
748 case vector::CombiningKind::ADD:
750 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
752 rewriter, loc, llvmType, operand, acc);
754 case vector::CombiningKind::MUL:
756 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
758 rewriter, loc, llvmType, operand, acc);
761 result = createIntegerReductionComparisonOpLowering<
762 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
763 LLVM::ICmpPredicate::ule);
765 case vector::CombiningKind::MINSI:
766 result = createIntegerReductionComparisonOpLowering<
767 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
768 LLVM::ICmpPredicate::sle);
770 case vector::CombiningKind::MAXUI:
771 result = createIntegerReductionComparisonOpLowering<
772 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
773 LLVM::ICmpPredicate::uge);
775 case vector::CombiningKind::MAXSI:
776 result = createIntegerReductionComparisonOpLowering<
777 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
778 LLVM::ICmpPredicate::sge);
780 case vector::CombiningKind::AND:
782 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
784 rewriter, loc, llvmType, operand, acc);
786 case vector::CombiningKind::OR:
788 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
790 rewriter, loc, llvmType, operand, acc);
792 case vector::CombiningKind::XOR:
794 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
796 rewriter, loc, llvmType, operand, acc);
806 if (!isa<FloatType>(eltType))
809 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
811 reductionOp.getContext(),
814 reductionOp.getContext(),
815 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
816 : LLVM::FastmathFlags::none));
820 if (kind == vector::CombiningKind::ADD) {
821 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
822 ReductionNeutralZero>(
823 rewriter, loc, llvmType, operand, acc, fmf);
824 }
else if (kind == vector::CombiningKind::MUL) {
825 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
826 ReductionNeutralFPOne>(
827 rewriter, loc, llvmType, operand, acc, fmf);
828 }
else if (kind == vector::CombiningKind::MINIMUMF) {
830 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
831 rewriter, loc, llvmType, operand, acc, fmf);
832 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
834 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
835 rewriter, loc, llvmType, operand, acc, fmf);
836 }
else if (kind == vector::CombiningKind::MINNUMF) {
837 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
838 rewriter, loc, llvmType, operand, acc, fmf);
839 }
else if (kind == vector::CombiningKind::MAXNUMF) {
840 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
841 rewriter, loc, llvmType, operand, acc, fmf);
850 const bool reassociateFPReductions;
861 template <
class MaskedOp>
862 class VectorMaskOpConversionBase
868 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
871 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
874 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
878 virtual LogicalResult
879 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880 vector::MaskableOpInterface maskableOp,
884 class MaskedReductionOpConversion
885 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
888 using VectorMaskOpConversionBase<
889 vector::ReductionOp>::VectorMaskOpConversionBase;
891 LogicalResult matchAndRewriteMaskableOp(
892 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
894 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
895 auto kind = reductionOp.getKind();
896 Type eltType = reductionOp.getDest().getType();
898 Value operand = reductionOp.getVector();
899 Value acc = reductionOp.getAcc();
900 Location loc = reductionOp.getLoc();
902 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
904 reductionOp.getContext(),
909 case vector::CombiningKind::ADD:
910 result = lowerPredicatedReductionWithStartValue<
911 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
912 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
915 case vector::CombiningKind::MUL:
916 result = lowerPredicatedReductionWithStartValue<
917 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
918 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
922 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
923 ReductionNeutralUIntMax>(
924 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
926 case vector::CombiningKind::MINSI:
927 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
928 ReductionNeutralSIntMax>(
929 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
931 case vector::CombiningKind::MAXUI:
932 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
933 ReductionNeutralUIntMin>(
934 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
936 case vector::CombiningKind::MAXSI:
937 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
938 ReductionNeutralSIntMin>(
939 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
941 case vector::CombiningKind::AND:
942 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
943 ReductionNeutralAllOnes>(
944 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
946 case vector::CombiningKind::OR:
947 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
948 ReductionNeutralZero>(
949 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
951 case vector::CombiningKind::XOR:
952 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
953 ReductionNeutralZero>(
954 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
956 case vector::CombiningKind::MINNUMF:
957 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
958 ReductionNeutralFPMax>(
959 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
961 case vector::CombiningKind::MAXNUMF:
962 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
963 ReductionNeutralFPMin>(
964 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
966 case CombiningKind::MAXIMUMF:
967 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
968 MaskNeutralFMaximum>(
969 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
971 case CombiningKind::MINIMUMF:
972 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
973 MaskNeutralFMinimum>(
974 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
979 rewriter.replaceOp(maskOp, result);
984 class VectorShuffleOpConversion
990 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
992 auto loc = shuffleOp->getLoc();
993 auto v1Type = shuffleOp.getV1VectorType();
994 auto v2Type = shuffleOp.getV2VectorType();
995 auto vectorType = shuffleOp.getResultVectorType();
1004 int64_t rank = vectorType.getRank();
1006 bool wellFormed0DCase =
1007 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008 bool wellFormedNDCase =
1009 v1Type.getRank() == rank && v2Type.getRank() == rank;
1010 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1015 if (rank <= 1 && v1Type == v2Type) {
1016 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1017 loc, adaptor.getV1(), adaptor.getV2(),
1018 llvm::to_vector_of<int32_t>(mask));
1019 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1024 int64_t v1Dim = v1Type.getDimSize(0);
1026 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027 eltType = arrayType.getElementType();
1029 eltType = cast<VectorType>(llvmType).getElementType();
1030 Value insert = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1032 for (int64_t extPos : mask) {
1033 Value value = adaptor.getV1();
1034 if (extPos >= v1Dim) {
1036 value = adaptor.getV2();
1039 eltType, rank, extPos);
1040 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1041 llvmType, rank, insPos++);
1048 class VectorExtractElementOpConversion
1055 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1057 auto vectorType = extractEltOp.getSourceVectorType();
1058 auto llvmType = typeConverter->
convertType(vectorType.getElementType());
1064 if (vectorType.getRank() == 0) {
1065 Location loc = extractEltOp.getLoc();
1067 auto zero = rewriter.
create<LLVM::ConstantOp>(
1071 extractEltOp, llvmType, adaptor.getVector(), zero);
1076 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1081 class VectorExtractOpConversion
1087 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1089 auto loc = extractOp->getLoc();
1090 auto resultType = extractOp.getResult().getType();
1091 auto llvmResultType = typeConverter->
convertType(resultType);
1093 if (!llvmResultType)
1097 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1111 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1115 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1116 extractOp.getSourceVectorType().getRank();
1120 if (extractOp.getSourceVectorType().getRank() == 0) {
1122 positionVec.push_back(rewriter.
getZeroAttr(idxType));
1125 Value extracted = adaptor.getVector();
1126 if (extractsAggregate) {
1128 if (extractsScalar) {
1132 position = position.drop_back();
1135 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1138 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1142 if (extractsScalar) {
1143 extracted = rewriter.
create<LLVM::ExtractElementOp>(
1144 loc, extracted,
getAsLLVMValue(rewriter, loc, positionVec.back()));
1147 rewriter.
replaceOp(extractOp, extracted);
1171 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1173 VectorType vType = fmaOp.getVectorType();
1174 if (vType.getRank() > 1)
1178 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1183 class VectorInsertElementOpConversion
1189 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1191 auto vectorType = insertEltOp.getDestVectorType();
1192 auto llvmType = typeConverter->
convertType(vectorType);
1198 if (vectorType.getRank() == 0) {
1199 Location loc = insertEltOp.getLoc();
1201 auto zero = rewriter.
create<LLVM::ConstantOp>(
1205 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1210 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1211 adaptor.getPosition());
1216 class VectorInsertOpConversion
1222 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1224 auto loc = insertOp->getLoc();
1225 auto sourceType = insertOp.getSourceType();
1226 auto destVectorType = insertOp.getDestVectorType();
1227 auto llvmResultType = typeConverter->
convertType(destVectorType);
1229 if (!llvmResultType)
1233 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1238 if (position.empty()) {
1239 rewriter.
replaceOp(insertOp, adaptor.getSource());
1244 if (isa<VectorType>(sourceType)) {
1245 if (insertOp.hasDynamicPosition())
1248 Value inserted = rewriter.
create<LLVM::InsertValueOp>(
1249 loc, adaptor.getDest(), adaptor.getSource(),
getAsIntegers(position));
1255 Value extracted = adaptor.getDest();
1256 auto oneDVectorType = destVectorType;
1257 if (position.size() > 1) {
1258 if (insertOp.hasDynamicPosition())
1262 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1267 Value inserted = rewriter.
create<LLVM::InsertElementOp>(
1268 loc, typeConverter->
convertType(oneDVectorType), extracted,
1269 adaptor.getSource(),
getAsLLVMValue(rewriter, loc, position.back()));
1272 if (position.size() > 1) {
1273 if (insertOp.hasDynamicPosition())
1276 inserted = rewriter.
create<LLVM::InsertValueOp>(
1277 loc, adaptor.getDest(), inserted,
1287 struct VectorScalableInsertOpLowering
1293 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1296 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1302 struct VectorScalableExtractOpLowering
1308 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1311 extOp, typeConverter->
convertType(extOp.getResultVectorType()),
1312 adaptor.getSource(), adaptor.getPos());
1345 setHasBoundedRewriteRecursion();
1348 LogicalResult matchAndRewrite(FMAOp op,
1350 auto vType = op.getVectorType();
1351 if (vType.getRank() < 2)
1354 auto loc = op.getLoc();
1355 auto elemType = vType.getElementType();
1358 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1359 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1360 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1361 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1362 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1363 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1364 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
1373 static std::optional<SmallVector<int64_t, 4>>
1374 computeContiguousStrides(MemRefType memRefType) {
1378 return std::nullopt;
1379 if (!strides.empty() && strides.back() != 1)
1380 return std::nullopt;
1382 if (memRefType.getLayout().isIdentity())
1389 auto sizes = memRefType.getShape();
1390 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1391 if (ShapedType::isDynamic(sizes[index + 1]) ||
1392 ShapedType::isDynamic(strides[index]) ||
1393 ShapedType::isDynamic(strides[index + 1]))
1394 return std::nullopt;
1395 if (strides[index] != strides[index + 1] * sizes[index + 1])
1396 return std::nullopt;
1401 class VectorTypeCastOpConversion
1407 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1409 auto loc = castOp->getLoc();
1410 MemRefType sourceMemRefType =
1411 cast<MemRefType>(castOp.getOperand().getType());
1412 MemRefType targetMemRefType = castOp.getType();
1415 if (!sourceMemRefType.hasStaticShape() ||
1416 !targetMemRefType.hasStaticShape())
1419 auto llvmSourceDescriptorTy =
1420 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1421 if (!llvmSourceDescriptorTy)
1425 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1427 if (!llvmTargetDescriptorTy)
1431 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1434 auto targetStrides = computeContiguousStrides(targetMemRefType);
1438 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1446 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1447 desc.setAllocatedPtr(rewriter, loc, allocated);
1450 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1451 desc.setAlignedPtr(rewriter, loc, ptr);
1454 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1455 desc.setOffset(rewriter, loc, zero);
1458 for (
const auto &indexedSize :
1460 int64_t index = indexedSize.index();
1463 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1464 desc.setSize(rewriter, loc, index, size);
1466 (*targetStrides)[index]);
1467 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1468 desc.setStride(rewriter, loc, index, stride);
1478 class VectorCreateMaskOpConversion
1481 explicit VectorCreateMaskOpConversion(
MLIRContext *context,
1482 bool enableIndexOpt)
1484 force32BitVectorIndices(enableIndexOpt) {}
1487 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1489 auto dstType = op.getType();
1490 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1492 IntegerType idxType =
1494 auto loc = op->getLoc();
1495 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1499 adaptor.getOperands()[0]);
1501 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1508 const bool force32BitVectorIndices;
1529 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1531 auto parent =
printOp->getParentOfType<ModuleOp>();
1537 if (
auto value = adaptor.getSource()) {
1543 if (failed(emitScalarPrint(rewriter, parent, loc,
printType, value)))
1547 auto punct =
printOp.getPunctuation();
1548 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1550 *stringLiteral, *getTypeConverter(),
1552 }
else if (punct != PrintPunctuation::NoPunctuation) {
1553 emitCall(rewriter,
printOp->getLoc(), [&] {
1555 case PrintPunctuation::Close:
1556 return LLVM::lookupOrCreatePrintCloseFn(parent);
1557 case PrintPunctuation::Open:
1558 return LLVM::lookupOrCreatePrintOpenFn(parent);
1559 case PrintPunctuation::Comma:
1560 return LLVM::lookupOrCreatePrintCommaFn(parent);
1561 case PrintPunctuation::NewLine:
1562 return LLVM::lookupOrCreatePrintNewlineFn(parent);
1564 llvm_unreachable(
"unexpected punctuation");
1574 enum class PrintConversion {
1585 Value value)
const {
1597 conversion = PrintConversion::Bitcast16;
1600 conversion = PrintConversion::Bitcast16;
1604 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1608 unsigned width = intTy.getWidth();
1609 if (intTy.isUnsigned()) {
1612 conversion = PrintConversion::ZeroExt64;
1618 assert(intTy.isSignless() || intTy.isSigned());
1623 conversion = PrintConversion::ZeroExt64;
1624 else if (width < 64)
1625 conversion = PrintConversion::SignExt64;
1635 switch (conversion) {
1636 case PrintConversion::ZeroExt64:
1637 value = rewriter.
create<arith::ExtUIOp>(
1640 case PrintConversion::SignExt64:
1641 value = rewriter.
create<arith::ExtSIOp>(
1644 case PrintConversion::Bitcast16:
1645 value = rewriter.
create<LLVM::BitcastOp>(
1651 emitCall(rewriter, loc, printer, value);
1669 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1671 VectorType resultType = cast<VectorType>(splatOp.getType());
1672 if (resultType.getRank() > 1)
1676 auto vectorType = typeConverter->
convertType(splatOp.getType());
1677 Value undef = rewriter.
create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1678 auto zero = rewriter.
create<LLVM::ConstantOp>(
1684 if (resultType.getRank() == 0) {
1686 splatOp, vectorType, undef, adaptor.getInput(), zero);
1691 auto v = rewriter.
create<LLVM::InsertElementOp>(
1692 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1694 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1711 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1713 VectorType resultType = splatOp.getType();
1714 if (resultType.getRank() <= 1)
1718 auto loc = splatOp.getLoc();
1719 auto vectorTypeInfo =
1721 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1722 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1723 if (!llvmNDVectorTy || !llvm1DVectorTy)
1727 Value desc = rewriter.
create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1731 Value vdesc = rewriter.
create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1732 auto zero = rewriter.
create<LLVM::ConstantOp>(
1735 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1736 adaptor.getInput(), zero);
1739 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1741 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1746 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1755 struct VectorInterleaveOpLowering
1760 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1762 VectorType resultType = interleaveOp.getResultVectorType();
1764 if (resultType.getRank() != 1)
1766 "InterleaveOp not rank 1");
1768 if (resultType.isScalable()) {
1770 interleaveOp, typeConverter->
convertType(resultType),
1771 adaptor.getLhs(), adaptor.getRhs());
1778 int64_t resultVectorSize = resultType.getNumElements();
1780 interleaveShuffleMask.reserve(resultVectorSize);
1781 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1782 interleaveShuffleMask.push_back(i);
1783 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1786 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1787 interleaveShuffleMask);
1794 struct VectorDeinterleaveOpLowering
1799 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1801 VectorType resultType = deinterleaveOp.getResultVectorType();
1802 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1803 auto loc = deinterleaveOp.getLoc();
1807 if (resultType.getRank() != 1)
1809 "DeinterleaveOp not rank 1");
1811 if (resultType.isScalable()) {
1812 auto llvmTypeConverter = this->getTypeConverter();
1813 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1814 auto packedOpResults =
1815 llvmTypeConverter->packOperationResults(deinterleaveResults);
1816 auto intrinsic = rewriter.
create<LLVM::vector_deinterleave2>(
1817 loc, packedOpResults, adaptor.getSource());
1819 auto evenResult = rewriter.
create<LLVM::ExtractValueOp>(
1820 loc, intrinsic->getResult(0), 0);
1821 auto oddResult = rewriter.
create<LLVM::ExtractValueOp>(
1822 loc, intrinsic->getResult(0), 1);
1831 int64_t resultVectorSize = resultType.getNumElements();
1835 evenShuffleMask.reserve(resultVectorSize);
1836 oddShuffleMask.reserve(resultVectorSize);
1838 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1840 evenShuffleMask.push_back(i);
1842 oddShuffleMask.push_back(i);
1845 auto poison = rewriter.
create<LLVM::PoisonOp>(loc, sourceType);
1846 auto evenShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1847 loc, adaptor.getSource(), poison, evenShuffleMask);
1848 auto oddShuffle = rewriter.
create<LLVM::ShuffleVectorOp>(
1849 loc, adaptor.getSource(), poison, oddShuffleMask);
1857 struct VectorFromElementsLowering
1862 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1864 Location loc = fromElementsOp.getLoc();
1865 VectorType vectorType = fromElementsOp.getType();
1868 if (vectorType.getRank() > 1)
1870 "rank > 1 vectors are not supported");
1872 Value result = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1874 result = rewriter.
create<vector::InsertOp>(loc, val, result, idx);
1875 rewriter.
replaceOp(fromElementsOp, result);
1881 struct VectorScalableStepOpLowering
1886 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1888 auto resultType = cast<VectorType>(stepOp.getType());
1889 if (!resultType.isScalable()) {
1908 bool reassociateFPReductions,
bool force32BitVectorIndices) {
1911 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1912 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
1913 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1914 VectorExtractElementOpConversion, VectorExtractOpConversion,
1915 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1916 VectorInsertOpConversion, VectorPrintOpConversion,
1917 VectorTypeCastOpConversion, VectorScaleOpConversion,
1918 VectorLoadStoreConversion<vector::LoadOp>,
1919 VectorLoadStoreConversion<vector::MaskedLoadOp>,
1920 VectorLoadStoreConversion<vector::StoreOp>,
1921 VectorLoadStoreConversion<vector::MaskedStoreOp>,
1922 VectorGatherOpConversion, VectorScatterOpConversion,
1923 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1924 VectorSplatOpLowering, VectorSplatNdOpLowering,
1925 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1926 MaskedReductionOpConversion, VectorInterleaveOpLowering,
1927 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1928 VectorScalableStepOpLowering>(converter);
1933 patterns.add<VectorMatmulOpConversion>(converter);
1934 patterns.add<VectorFlatTransposeOpConversion>(converter);
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...