31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/Support/Casting.h"
45 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
48 auto constant = LLVM::ConstantOp::create(
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
60 Value val,
Type llvmType, int64_t rank, int64_t pos) {
63 auto constant = LLVM::ConstantOp::create(
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
74 VectorType vectorType,
unsigned &align) {
76 if (!convertedVectorTy)
79 llvm::LLVMContext llvmContext;
80 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
81 .getPreferredAlignment(convertedVectorTy,
89 MemRefType memrefType,
unsigned &align) {
90 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
96 llvm::LLVMContext llvmContext;
97 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
98 .getPreferredAlignment(elementTy, typeConverter.
getDataLayout());
109 VectorType vectorType,
110 MemRefType memrefType,
unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
127 if (!memRefType.isLastDimUnitStride())
137 MemRefType memRefType,
Value llvmMemref,
Value base,
138 Value index, VectorType vectorType) {
140 "unsupported memref type");
141 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
145 vectorType.getScalableDims()[0]);
146 return LLVM::GEPOp::create(
147 rewriter, loc, ptrsType,
148 typeConverter.
convertType(memRefType.getElementType()), base, index);
155 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
160 return cast<Value>(foldResult);
166 using VectorScaleOpConversion =
170 class VectorBitCastOpConversion
176 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
179 VectorType resultTy = bitCastOp.getResultVectorType();
180 if (resultTy.getRank() > 1)
182 Type newResultTy = typeConverter->convertType(resultTy);
184 adaptor.getOperands()[0]);
192 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193 vector::LoadOpAdaptor adaptor,
194 VectorType vectorTy,
Value ptr,
unsigned align,
198 loadOp.getNontemporal());
201 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202 vector::MaskedLoadOpAdaptor adaptor,
203 VectorType vectorTy,
Value ptr,
unsigned align,
206 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
209 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210 vector::StoreOpAdaptor adaptor,
211 VectorType vectorTy,
Value ptr,
unsigned align,
215 storeOp.getNontemporal());
218 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219 vector::MaskedStoreOpAdaptor adaptor,
220 VectorType vectorTy,
Value ptr,
unsigned align,
223 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
228 template <
class LoadOrStoreOp>
234 useVectorAlignment(useVectorAlign) {}
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
254 memRefTy, align, useVectorAlignment)))
256 "could not resolve alignment");
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
273 const bool useVectorAlignment;
277 class VectorGatherOpConversion
283 useVectorAlignment(useVectorAlign) {}
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType &&
"The base should be bufferized");
296 VectorType vType = gather.getVectorType();
297 if (vType.getRank() > 1) {
299 gather,
"only 1-D vectors can be lowered to LLVM");
304 unsigned align = gather.getAlignment().value_or(0);
307 memRefType, align, useVectorAlignment)))
312 adaptor.getBase(), adaptor.getOffsets());
313 Value base = adaptor.getBase();
315 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
316 base, ptr, adaptor.getIndices(), vType);
320 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
330 const bool useVectorAlignment;
334 class VectorScatterOpConversion
340 useVectorAlignment(useVectorAlign) {}
345 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
347 auto loc = scatter->getLoc();
348 MemRefType memRefType = scatter.getMemRefType();
353 VectorType vType = scatter.getVectorType();
354 if (vType.getRank() > 1) {
356 scatter,
"only 1-D vectors can be lowered to LLVM");
361 unsigned align = scatter.getAlignment().value_or(0);
364 memRefType, align, useVectorAlignment)))
366 "could not resolve alignment");
370 adaptor.getBase(), adaptor.getOffsets());
372 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
373 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
377 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
387 const bool useVectorAlignment;
391 class VectorExpandLoadOpConversion
397 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
399 auto loc = expand->getLoc();
400 MemRefType memRefType = expand.getMemRefType();
403 auto vtype = typeConverter->convertType(expand.getVectorType());
405 adaptor.getBase(), adaptor.getIndices());
410 uint64_t alignment = expand.getAlignment().value_or(1);
413 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
420 class VectorCompressStoreOpConversion
426 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
428 auto loc = compress->getLoc();
429 MemRefType memRefType = compress.getMemRefType();
433 adaptor.getBase(), adaptor.getIndices());
438 uint64_t alignment = compress.getAlignment().value_or(1);
441 compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
// Empty tag types that select the neutral (identity) starting value for a
// vector reduction. Each tag is dispatched to a matching overload of
// createReductionNeutralValue(...) below, which materializes the neutral
// element as an LLVM::ConstantOp of the reduction's element type (e.g. zero
// for additive reductions, all-ones for AND, QNaN for FP min/max — see the
// FPMin/FPMax overloads further down).
447 class ReductionNeutralZero {};
448 class ReductionNeutralIntOne {};
449 class ReductionNeutralFPOne {};
450 class ReductionNeutralAllOnes {};
451 class ReductionNeutralSIntMin {};
452 class ReductionNeutralUIntMin {};
453 class ReductionNeutralSIntMax {};
454 class ReductionNeutralUIntMax {};
455 class ReductionNeutralFPMin {};
456 class ReductionNeutralFPMax {};
459 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
462 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
467 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
470 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
475 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
478 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
483 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
486 return LLVM::ConstantOp::create(
487 rewriter, loc, llvmType,
493 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
496 return LLVM::ConstantOp::create(
497 rewriter, loc, llvmType,
503 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
506 return LLVM::ConstantOp::create(
507 rewriter, loc, llvmType,
513 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
516 return LLVM::ConstantOp::create(
517 rewriter, loc, llvmType,
523 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
526 return LLVM::ConstantOp::create(
527 rewriter, loc, llvmType,
533 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
536 auto floatType = cast<FloatType>(llvmType);
537 return LLVM::ConstantOp::create(
538 rewriter, loc, llvmType,
540 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
545 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
548 auto floatType = cast<FloatType>(llvmType);
549 return LLVM::ConstantOp::create(
550 rewriter, loc, llvmType,
552 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
558 template <
class ReductionNeutral>
565 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
574 VectorType vType = cast<VectorType>(llvmType);
575 auto vShape = vType.getShape();
576 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
578 Value baseVecLength = LLVM::ConstantOp::create(
582 if (!vType.getScalableDims()[0])
583 return baseVecLength;
586 Value vScale = vector::VectorScaleOp::create(rewriter, loc);
588 arith::IndexCastOp::create(rewriter, loc, rewriter.
getI32Type(), vScale);
589 Value scalableVecLength =
590 arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
591 return scalableVecLength;
598 template <
class LLVMRedIntrinOp,
class ScalarOp>
599 static Value createIntegerReductionArithmeticOpLowering(
604 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
607 result = ScalarOp::create(rewriter, loc, accumulator, result);
615 template <
class LLVMRedIntrinOp>
616 static Value createIntegerReductionComparisonOpLowering(
618 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
620 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
623 LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
624 result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
// Maps an LLVM vector-reduction intrinsic op to the scalar LLVM op used to
// combine the reduced value with the accumulator (see the explicit
// specializations that follow, e.g. vector_reduce_fmaximum -> MaximumOp, and
// the use of VectorToScalarMapper<...>::Type::create in
// createFPReductionComparisonOpLowering).
630 template <
typename Source>
631 struct VectorToScalarMapper;
633 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
634 using Type = LLVM::MaximumOp;
637 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
638 using Type = LLVM::MinimumOp;
641 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
642 using Type = LLVM::MaxNumOp;
645 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
646 using Type = LLVM::MinNumOp;
650 template <
class LLVMRedIntrinOp>
651 static Value createFPReductionComparisonOpLowering(
653 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
655 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
658 result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
659 rewriter, loc, result, accumulator);
// Empty tag types selecting the mask-neutral fill value for masked FP
// maximum/minimum reductions: masked-off lanes are replaced with a value that
// cannot win the comparison (smallest for fmaximum, largest for fminimum —
// see the getMaskNeutralValue overloads below).
666 class MaskNeutralFMaximum {};
667 class MaskNeutralFMinimum {};
671 getMaskNeutralValue(MaskNeutralFMaximum,
672 const llvm::fltSemantics &floatSemantics) {
673 return llvm::APFloat::getSmallest(floatSemantics,
true);
677 getMaskNeutralValue(MaskNeutralFMinimum,
678 const llvm::fltSemantics &floatSemantics) {
679 return llvm::APFloat::getLargest(floatSemantics,
false);
683 template <
typename MaskNeutral>
687 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
688 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
690 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
697 template <
class LLVMRedIntrinOp,
class MaskNeutral>
702 Value mask, LLVM::FastmathFlagsAttr fmf) {
703 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
704 rewriter, loc, llvmType, vectorOperand.
getType());
705 const Value selectedVectorByMask = LLVM::SelectOp::create(
706 rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
707 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
708 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
711 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
715 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
716 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
717 llvmType, accumulator);
718 return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
719 accumulator, vectorOperand,
726 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
731 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
732 llvmType, accumulator);
733 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
734 accumulator, vectorOperand);
737 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
738 static Value lowerPredicatedReductionWithStartValue(
741 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
742 llvmType, accumulator);
744 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
745 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
746 accumulator, vectorOperand,
750 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
751 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
752 static Value lowerPredicatedReductionWithStartValue(
756 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
757 IntReductionNeutral>(
758 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
761 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
763 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
767 class VectorReductionOpConversion
771 bool reassociateFPRed)
773 reassociateFPReductions(reassociateFPRed) {}
776 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
778 auto kind = reductionOp.getKind();
779 Type eltType = reductionOp.getDest().getType();
780 Type llvmType = typeConverter->convertType(eltType);
781 Value operand = adaptor.getVector();
782 Value acc = adaptor.getAcc();
783 Location loc = reductionOp.getLoc();
789 case vector::CombiningKind::ADD:
791 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
793 rewriter, loc, llvmType, operand, acc);
795 case vector::CombiningKind::MUL:
797 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
799 rewriter, loc, llvmType, operand, acc);
802 result = createIntegerReductionComparisonOpLowering<
803 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
804 LLVM::ICmpPredicate::ule);
806 case vector::CombiningKind::MINSI:
807 result = createIntegerReductionComparisonOpLowering<
808 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
809 LLVM::ICmpPredicate::sle);
811 case vector::CombiningKind::MAXUI:
812 result = createIntegerReductionComparisonOpLowering<
813 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
814 LLVM::ICmpPredicate::uge);
816 case vector::CombiningKind::MAXSI:
817 result = createIntegerReductionComparisonOpLowering<
818 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
819 LLVM::ICmpPredicate::sge);
821 case vector::CombiningKind::AND:
823 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
825 rewriter, loc, llvmType, operand, acc);
827 case vector::CombiningKind::OR:
829 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
831 rewriter, loc, llvmType, operand, acc);
833 case vector::CombiningKind::XOR:
835 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
837 rewriter, loc, llvmType, operand, acc);
847 if (!isa<FloatType>(eltType))
850 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
852 reductionOp.getContext(),
855 reductionOp.getContext(),
856 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
857 : LLVM::FastmathFlags::none));
861 if (
kind == vector::CombiningKind::ADD) {
862 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
863 ReductionNeutralZero>(
864 rewriter, loc, llvmType, operand, acc, fmf);
865 }
else if (
kind == vector::CombiningKind::MUL) {
866 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
867 ReductionNeutralFPOne>(
868 rewriter, loc, llvmType, operand, acc, fmf);
869 }
else if (
kind == vector::CombiningKind::MINIMUMF) {
871 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
872 rewriter, loc, llvmType, operand, acc, fmf);
873 }
else if (
kind == vector::CombiningKind::MAXIMUMF) {
875 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
876 rewriter, loc, llvmType, operand, acc, fmf);
877 }
else if (
kind == vector::CombiningKind::MINNUMF) {
878 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
879 rewriter, loc, llvmType, operand, acc, fmf);
880 }
else if (
kind == vector::CombiningKind::MAXNUMF) {
881 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
882 rewriter, loc, llvmType, operand, acc, fmf);
892 const bool reassociateFPReductions;
903 template <
class MaskedOp>
904 class VectorMaskOpConversionBase
910 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
913 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
916 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
920 virtual LogicalResult
921 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
922 vector::MaskableOpInterface maskableOp,
926 class MaskedReductionOpConversion
927 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
930 using VectorMaskOpConversionBase<
931 vector::ReductionOp>::VectorMaskOpConversionBase;
933 LogicalResult matchAndRewriteMaskableOp(
934 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
936 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
937 auto kind = reductionOp.getKind();
938 Type eltType = reductionOp.getDest().getType();
939 Type llvmType = typeConverter->convertType(eltType);
940 Value operand = reductionOp.getVector();
941 Value acc = reductionOp.getAcc();
942 Location loc = reductionOp.getLoc();
944 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
946 reductionOp.getContext(),
951 case vector::CombiningKind::ADD:
952 result = lowerPredicatedReductionWithStartValue<
953 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
954 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
957 case vector::CombiningKind::MUL:
958 result = lowerPredicatedReductionWithStartValue<
959 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
960 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
964 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
965 ReductionNeutralUIntMax>(
966 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
968 case vector::CombiningKind::MINSI:
969 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
970 ReductionNeutralSIntMax>(
971 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
973 case vector::CombiningKind::MAXUI:
974 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
975 ReductionNeutralUIntMin>(
976 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
978 case vector::CombiningKind::MAXSI:
979 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
980 ReductionNeutralSIntMin>(
981 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
983 case vector::CombiningKind::AND:
984 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
985 ReductionNeutralAllOnes>(
986 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
988 case vector::CombiningKind::OR:
989 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
990 ReductionNeutralZero>(
991 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
993 case vector::CombiningKind::XOR:
994 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
995 ReductionNeutralZero>(
996 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
998 case vector::CombiningKind::MINNUMF:
999 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1000 ReductionNeutralFPMax>(
1001 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1003 case vector::CombiningKind::MAXNUMF:
1004 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1005 ReductionNeutralFPMin>(
1006 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1008 case CombiningKind::MAXIMUMF:
1009 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1010 MaskNeutralFMaximum>(
1011 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1013 case CombiningKind::MINIMUMF:
1014 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1015 MaskNeutralFMinimum>(
1016 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1021 rewriter.replaceOp(maskOp, result);
1026 class VectorShuffleOpConversion
1032 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1034 auto loc = shuffleOp->getLoc();
1035 auto v1Type = shuffleOp.getV1VectorType();
1036 auto v2Type = shuffleOp.getV2VectorType();
1037 auto vectorType = shuffleOp.getResultVectorType();
1038 Type llvmType = typeConverter->convertType(vectorType);
1046 int64_t rank = vectorType.getRank();
1048 bool wellFormed0DCase =
1049 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1050 bool wellFormedNDCase =
1051 v1Type.getRank() == rank && v2Type.getRank() == rank;
1052 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1057 if (rank <= 1 && v1Type == v2Type) {
1058 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1059 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1060 llvm::to_vector_of<int32_t>(mask));
1061 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1066 int64_t v1Dim = v1Type.getDimSize(0);
1068 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1069 eltType = arrayType.getElementType();
1071 eltType = cast<VectorType>(llvmType).getElementType();
1072 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1074 for (int64_t extPos : mask) {
1075 Value value = adaptor.getV1();
1076 if (extPos >= v1Dim) {
1078 value = adaptor.getV2();
1081 eltType, rank, extPos);
1082 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1083 llvmType, rank, insPos++);
1090 class VectorExtractOpConversion
1096 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1098 auto loc = extractOp->getLoc();
1099 auto resultType = extractOp.getResult().getType();
1100 auto llvmResultType = typeConverter->convertType(resultType);
1102 if (!llvmResultType)
1106 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1120 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1124 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1125 extractOp.getSourceVectorType().getRank();
1129 if (extractOp.getSourceVectorType().getRank() == 0) {
1131 positionVec.push_back(rewriter.
getZeroAttr(idxType));
1134 Value extracted = adaptor.getSource();
1135 if (extractsAggregate) {
1137 if (extractsScalar) {
1141 position = position.drop_back();
1144 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1147 extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1151 if (extractsScalar) {
1152 extracted = LLVM::ExtractElementOp::create(
1153 rewriter, loc, extracted,
1157 rewriter.
replaceOp(extractOp, extracted);
1181 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1183 VectorType vType = fmaOp.getVectorType();
1184 if (vType.getRank() > 1)
1188 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1193 class VectorInsertOpConversion
1199 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1201 auto loc = insertOp->getLoc();
1202 auto destVectorType = insertOp.getDestVectorType();
1203 auto llvmResultType = typeConverter->convertType(destVectorType);
1205 if (!llvmResultType)
1209 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1231 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1233 bool insertIntoInnermostDim =
1234 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
1237 positionVec.begin(),
1238 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1240 if (destVectorType.getRank() == 0) {
1244 positionOfScalarWithin1DVector = rewriter.
getZeroAttr(idxType);
1245 }
else if (insertIntoInnermostDim) {
1246 positionOfScalarWithin1DVector = positionVec.back();
1252 Value sourceAggregate = adaptor.getValueToStore();
1253 if (insertIntoInnermostDim) {
1256 if (isNestedAggregate) {
1259 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1260 llvm::IsaPred<Attribute>)) {
1264 sourceAggregate = LLVM::ExtractValueOp::create(
1265 rewriter, loc, adaptor.getDest(),
1270 sourceAggregate = adaptor.getDest();
1273 sourceAggregate = LLVM::InsertElementOp::create(
1274 rewriter, loc, sourceAggregate.
getType(), sourceAggregate,
1275 adaptor.getValueToStore(),
1279 Value result = sourceAggregate;
1280 if (isNestedAggregate) {
1281 result = LLVM::InsertValueOp::create(
1282 rewriter, loc, adaptor.getDest(), sourceAggregate,
1292 struct VectorScalableInsertOpLowering
1298 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1301 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1307 struct VectorScalableExtractOpLowering
1313 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1316 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1317 adaptor.getSource(), adaptor.getPos());
1350 setHasBoundedRewriteRecursion();
1353 LogicalResult matchAndRewrite(FMAOp op,
1355 auto vType = op.getVectorType();
1356 if (vType.getRank() < 2)
1359 auto loc = op.getLoc();
1360 auto elemType = vType.getElementType();
1361 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1363 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1364 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1365 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1366 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1367 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1368 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1369 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1378 static std::optional<SmallVector<int64_t, 4>>
1379 computeContiguousStrides(MemRefType memRefType) {
1382 if (
failed(memRefType.getStridesAndOffset(strides, offset)))
1383 return std::nullopt;
1384 if (!strides.empty() && strides.back() != 1)
1385 return std::nullopt;
1387 if (memRefType.getLayout().isIdentity())
1394 auto sizes = memRefType.getShape();
1395 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1396 if (ShapedType::isDynamic(sizes[index + 1]) ||
1397 ShapedType::isDynamic(strides[index]) ||
1398 ShapedType::isDynamic(strides[index + 1]))
1399 return std::nullopt;
1400 if (strides[index] != strides[index + 1] * sizes[index + 1])
1401 return std::nullopt;
1406 class VectorTypeCastOpConversion
1412 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1414 auto loc = castOp->getLoc();
1415 MemRefType sourceMemRefType =
1416 cast<MemRefType>(castOp.getOperand().getType());
1417 MemRefType targetMemRefType = castOp.getType();
1420 if (!sourceMemRefType.hasStaticShape() ||
1421 !targetMemRefType.hasStaticShape())
1424 auto llvmSourceDescriptorTy =
1425 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1426 if (!llvmSourceDescriptorTy)
1430 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1431 typeConverter->convertType(targetMemRefType));
1432 if (!llvmTargetDescriptorTy)
1436 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1439 auto targetStrides = computeContiguousStrides(targetMemRefType);
1443 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1451 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1452 desc.setAllocatedPtr(rewriter, loc, allocated);
1455 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1456 desc.setAlignedPtr(rewriter, loc, ptr);
1459 auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1460 desc.setOffset(rewriter, loc, zero);
1463 for (
const auto &indexedSize :
1465 int64_t index = indexedSize.index();
1468 auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1469 desc.setSize(rewriter, loc, index, size);
1471 (*targetStrides)[index]);
1473 LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1474 desc.setStride(rewriter, loc, index, stride);
1484 class VectorCreateMaskOpConversion
1487 explicit VectorCreateMaskOpConversion(
MLIRContext *context,
1488 bool enableIndexOpt)
1490 force32BitVectorIndices(enableIndexOpt) {}
1493 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1495 auto dstType = op.getType();
1496 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1498 IntegerType idxType =
1500 auto loc = op->getLoc();
1501 Value indices = LLVM::StepVectorOp::create(
1506 adaptor.getOperands()[0]);
1507 Value bounds = BroadcastOp::create(rewriter, loc, indices.
getType(), bound);
1508 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1515 const bool force32BitVectorIndices;
1522 explicit VectorPrintOpConversion(
1526 symbolTables(symbolTables) {}
1542 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1544 auto parent =
printOp->getParentOfType<ModuleOp>();
1550 if (
auto value = adaptor.getSource()) {
1560 auto punct =
printOp.getPunctuation();
1561 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1564 *stringLiteral, *getTypeConverter(),
1566 if (createResult.failed())
1569 }
else if (punct != PrintPunctuation::NoPunctuation) {
1570 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1572 case PrintPunctuation::Close:
1575 case PrintPunctuation::Open:
1578 case PrintPunctuation::Comma:
1581 case PrintPunctuation::NewLine:
1585 llvm_unreachable(
"unexpected punctuation");
1590 emitCall(rewriter,
printOp->getLoc(), op.value());
1598 enum class PrintConversion {
1609 Value value)
const {
1610 if (typeConverter->convertType(
printType) ==
nullptr)
1615 FailureOr<Operation *> printer;
1621 conversion = PrintConversion::Bitcast16;
1624 conversion = PrintConversion::Bitcast16;
1628 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1632 unsigned width = intTy.getWidth();
1633 if (intTy.isUnsigned()) {
1636 conversion = PrintConversion::ZeroExt64;
1643 assert(intTy.isSignless() || intTy.isSigned());
1648 conversion = PrintConversion::ZeroExt64;
1649 else if (width < 64)
1650 conversion = PrintConversion::SignExt64;
1663 switch (conversion) {
1664 case PrintConversion::ZeroExt64:
1665 value = arith::ExtUIOp::create(
1668 case PrintConversion::SignExt64:
1669 value = arith::ExtSIOp::create(
1672 case PrintConversion::Bitcast16:
1673 value = LLVM::BitcastOp::create(
1679 emitCall(rewriter, loc, printer.value(), value);
1694 struct VectorBroadcastScalarToLowRankLowering
1699 matchAndRewrite(vector::BroadcastOp
broadcast, OpAdaptor adaptor,
1701 if (isa<VectorType>(
broadcast.getSourceType()))
1703 broadcast,
"broadcast from vector type not handled");
1706 if (resultType.getRank() > 1)
1708 "broadcast to 2+-d handled elsewhere");
1714 auto zero = LLVM::ConstantOp::create(
1720 if (resultType.getRank() == 0) {
1722 broadcast, vectorType, poison, adaptor.getSource(), zero);
1728 LLVM::InsertElementOp::create(rewriter,
broadcast.
getLoc(), vectorType,
1729 poison, adaptor.getSource(), zero);
1745 struct VectorBroadcastScalarToNdLowering
1750 matchAndRewrite(BroadcastOp
broadcast, OpAdaptor adaptor,
1752 if (isa<VectorType>(
broadcast.getSourceType()))
1754 broadcast,
"broadcast from vector type not handled");
1757 if (resultType.getRank() <= 1)
1759 broadcast,
"broadcast to 1-d or 0-d handled elsewhere");
1763 auto vectorTypeInfo =
1765 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1766 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1767 if (!llvmNDVectorTy || !llvm1DVectorTy)
1771 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1775 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1776 auto zero = LLVM::ConstantOp::create(
1777 rewriter, loc, typeConverter->convertType(rewriter.
getIntegerType(32)),
1779 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1780 vdesc, adaptor.getSource(), zero);
1783 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1785 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1790 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
// Lowers vector.interleave (rank-1 only). Scalable vectors map to an LLVM
// interleave intrinsic; fixed-size vectors become a single shufflevector
// with mask [0, n/2, 1, n/2+1, ...].
// NOTE(review): decimated listing; interior lines are missing.
1799 struct VectorInterleaveOpLowering
1804 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1806 VectorType resultType = interleaveOp.getResultVectorType();
// Only 1-D results are handled by this pattern.
1808 if (resultType.getRank() != 1)
1810 "InterleaveOp not rank 1");
// Scalable case: no static shuffle mask is possible, so delegate to the
// LLVM intrinsic (replacement op elided from this listing).
1812 if (resultType.isScalable()) {
1814 interleaveOp, typeConverter->convertType(resultType),
1815 adaptor.getLhs(), adaptor.getRhs());
// Fixed-size case: build the interleaving shuffle mask. Each source vector
// contributes resultVectorSize / 2 elements, alternating lhs/rhs lanes.
1822 int64_t resultVectorSize = resultType.getNumElements();
1824 interleaveShuffleMask.reserve(resultVectorSize);
1825 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1826 interleaveShuffleMask.push_back(i);
1827 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
// Replace with a shuffle of (lhs, rhs) under the mask built above.
1830 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1831 interleaveShuffleMask);
// Lowers vector.deinterleave (rank-1 only). Scalable vectors use the
// llvm.vector.deinterleave2 intrinsic (struct of two results); fixed-size
// vectors use two shuffles selecting even and odd lanes.
// NOTE(review): decimated listing; interior lines are missing.
1838 struct VectorDeinterleaveOpLowering
1843 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1845 VectorType resultType = deinterleaveOp.getResultVectorType();
1846 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1847 auto loc = deinterleaveOp.getLoc();
// Only 1-D results are handled here.
1851 if (resultType.getRank() != 1)
1853 "DeinterleaveOp not rank 1");
// Scalable case: pack the two result types into one LLVM struct, call the
// deinterleave2 intrinsic, then extract field 0 (even) and field 1 (odd).
1855 if (resultType.isScalable()) {
1856 const auto *llvmTypeConverter = this->getTypeConverter();
1857 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1858 auto packedOpResults =
1859 llvmTypeConverter->packOperationResults(deinterleaveResults);
1860 auto intrinsic = LLVM::vector_deinterleave2::create(
1861 rewriter, loc, packedOpResults, adaptor.getSource());
1863 auto evenResult = LLVM::ExtractValueOp::create(
1864 rewriter, loc, intrinsic->getResult(0), 0);
1865 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1866 intrinsic->getResult(0), 1);
// Fixed-size case: build even/odd index masks over the source lanes.
// The even/odd selection condition inside the loop is elided from this
// listing; only the push_back calls survive.
1875 int64_t resultVectorSize = resultType.getNumElements();
1879 evenShuffleMask.reserve(resultVectorSize);
1880 oddShuffleMask.reserve(resultVectorSize);
1882 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1884 evenShuffleMask.push_back(i);
1886 oddShuffleMask.push_back(i);
// Shuffle against a poison second operand: only source lanes are selected.
1889 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1890 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1891 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1892 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1893 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
// Lowers vector.from_elements (rank <= 1): start from a poison vector and
// insert each scalar element at its constant index.
// NOTE(review): decimated listing; the loop header and final index/element
// operands are missing from this view.
1901 struct VectorFromElementsLowering
1906 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1908 Location loc = fromElementsOp.getLoc();
1909 VectorType vectorType = fromElementsOp.getType();
// Higher-rank from_elements is not supported by this pattern.
1913 if (vectorType.getRank() > 1)
1915 "rank > 1 vectors are not supported");
1916 Type llvmType = typeConverter->convertType(vectorType);
// Accumulate insertions into an initially-poison vector.
1918 Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
// Per element: constant lane index, then insertelement into `result`.
1921 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1922 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1925 rewriter.
replaceOp(fromElementsOp, result);
// Lowers vector.to_elements: one extractelement per result, skipping results
// that have no uses (no point materializing dead extracts).
// NOTE(review): decimated listing; the constant-index operands are partially
// elided.
1931 struct VectorToElementsLowering
1936 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1938 Location loc = toElementsOp.getLoc();
// Lane indices are built in the converted index type.
1939 auto idxType = typeConverter->convertType(rewriter.
getIndexType());
1940 Value source = adaptor.getSource();
1943 for (
auto [idx, element] :
llvm::enumerate(toElementsOp.getElements())) {
// Dead results are skipped entirely (results[idx] stays unset/null).
1945 if (element.use_empty())
1948 auto constIdx = LLVM::ConstantOp::create(
1950 auto llvmType = typeConverter->convertType(element.getType());
1952 Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1954 results[idx] = result;
1957 rewriter.
replaceOp(toElementsOp, results);
// Lowers vector.step for scalable vectors only; fixed-size step ops are
// rejected (handled elsewhere, presumably as a constant — the rejection body
// is elided from this listing).
1963 struct VectorScalableStepOpLowering
1968 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1970 auto resultType = cast<VectorType>(stepOp.getType());
1971 if (!resultType.isScalable()) {
1974 Type llvmType = typeConverter->convertType(stepOp.getType());
// Declaration fragment: converts vector.contract (matmul-shaped, with
// optional mask) to llvm.intr.matrix.multiply; implementation follows below.
1990 class ContractionOpToMatmulOpLowering
1993 using MaskableOpRewritePattern::MaskableOpRewritePattern;
1995 ContractionOpToMatmulOpLowering(
MLIRContext *context,
2000 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
// Rewrites a 2-D vector.contract as llvm.intr.matrix.multiply:
// flatten lhs/rhs to 1-D via vector.shape_cast, multiply, shape_cast the
// product back, then add the accumulator (addi for ints, addf for floats).
// NOTE(review): decimated listing; several guards/returns are elided.
2021 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2022 vector::ContractionOp op, MaskingOpInterface maskOp,
2028 auto iteratorTypes = op.getIteratorTypes().getValue();
// Scalable results are not supported by the matrix intrinsic path.
2034 Type opResType = op.getType();
2035 VectorType vecType = dyn_cast<VectorType>(opResType);
2036 if (vecType && vecType.isScalable()) {
// Mixed element types (lhs vs destination) are rejected.
2041 Type elementType = op.getLhsType().getElementType();
2045 Type dstElementType = vecType ? vecType.getElementType() : opResType;
2046 if (elementType != dstElementType)
// Indexing maps for lhs/rhs (checks on the maps are elided here).
2056 Value lhs = op.getLhs();
2057 auto lhsMap = op.getIndexingMapsArray()[0];
2064 Value rhs = op.getRhs();
2065 auto rhsMap = op.getIndexingMapsArray()[1];
// Matrix dimensions: lhs is (lhsRows x lhsColumns), rhs trailing dim gives
// rhsColumns; lhsColumns is the shared/contraction dimension.
2072 VectorType lhsType = cast<VectorType>(lhs.
getType());
2073 VectorType rhsType = cast<VectorType>(rhs.
getType());
2074 int64_t lhsRows = lhsType.getDimSize(0);
2075 int64_t lhsColumns = lhsType.getDimSize(1);
2076 int64_t rhsColumns = rhsType.getDimSize(1);
// llvm.intr.matrix.multiply takes flat 1-D vectors plus the dims as
// attributes, so shape_cast both operands down to 1-D.
2078 Type flattenedLHSType =
2080 lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);
2082 Type flattenedRHSType =
2084 rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);
2086 Value mul = LLVM::MatrixMultiplyOp::create(
2089 cast<VectorType>(lhs.
getType()).getElementType()),
2090 lhs, rhs, lhsRows, lhsColumns, rhsColumns);
// Cast the flat product back to the op's 2-D result type.
2092 mul = vector::ShapeCastOp::create(
// The acc map must match one of the expected layouts; anything else is a
// verifier-guaranteed impossibility.
2099 auto accMap = op.getIndexingMapsArray()[2];
2103 llvm_unreachable(
"invalid contraction semantics");
// Accumulate: integer element type -> arith.addi, otherwise arith.addf.
2105 Value res = isa<IntegerType>(elementType)
2106 ?
static_cast<Value>(
2107 arith::AddIOp::create(rew, loc, op.getAcc(), mul))
2108 :
static_cast<Value>(
2109 arith::AddFOp::create(rew, loc, op.getAcc(), mul));
// Rewrites a 2-D vector.transpose with permutation [1, 0] as
// llvm.intr.matrix.transpose on the flattened (1-D) vector.
// NOTE(review): decimated listing; the rows/columns computation and the
// final shape_cast/replace are elided.
2127 class TransposeOpToMatrixTransposeOpLowering
2132 LogicalResult matchAndRewrite(vector::TransposeOp op,
2134 auto loc = op.getLoc();
2136 Value input = op.getVector();
2137 VectorType inputType = op.getSourceVectorType();
2138 VectorType resType = op.getResultVectorType();
// Scalable vectors cannot be flattened to a static 1-D shape.
2140 if (inputType.isScalable())
2142 op,
"This lowering does not support scalable vectors");
// Only the plain 2-D swap permutation is handled.
2147 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
// Flatten to 1-D, then call the matrix-transpose intrinsic with the
// row/column dims as attributes.
2151 Type flattenedType =
2154 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2157 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2158 matrix,
rows, columns);
// Registers the contract -> llvm.intr.matrix.multiply pattern.
2173 patterns.add<ContractionOpToMatmulOpLowering>(
patterns.getContext(), benefit);
// Registers the transpose -> llvm.intr.matrix.transpose pattern.
2178 patterns.add<TransposeOpToMatrixTransposeOpLowering>(
patterns.getContext(),
// Registers the full set of Vector -> LLVM conversion patterns.
// Flags: reassociateFPReductions (allow FP reduction reassociation),
// force32BitVectorIndices (i32 instead of i64 lane indices), and
// useVectorAlignment (align memory ops on vector width instead of the
// memref element type).
2185 bool reassociateFPReductions,
bool force32BitVectorIndices,
2186 bool useVectorAlignment) {
2189 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2190 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
// Memory-access patterns share the useVectorAlignment knob.
2191 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2192 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2193 VectorLoadStoreConversion<vector::StoreOp>,
2194 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2195 VectorGatherOpConversion, VectorScatterOpConversion>(
2196 converter, useVectorAlignment);
// Remaining patterns take only the type converter.
2197 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2198 VectorExtractOpConversion, VectorFMAOp1DConversion,
2199 VectorInsertOpConversion, VectorPrintOpConversion,
2200 VectorTypeCastOpConversion, VectorScaleOpConversion,
2201 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2202 VectorBroadcastScalarToLowRankLowering,
2203 VectorBroadcastScalarToNdLowering,
2204 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2205 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2206 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2207 VectorToElementsLowering, VectorScalableStepOpLowering>(
// ConvertToLLVMPatternInterface hook: the produced patterns create LLVM
// dialect ops, so the LLVM dialect must be loaded up front.
2214 void loadDependentDialects(
MLIRContext *context)
const final {
2215 context->loadDialect<LLVM::LLVMDialect>();
// Interface hook / registration fragment: attaches the dialect interface
// that plugs Vector -> LLVM patterns into the generic convert-to-llvm pass.
2220 void populateConvertToLLVMConversionPatterns(
2231 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
union mlir::linalg::@1247::ArityGroupAndKind::Kind kind
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (extremely useful).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void registerConvertVectorToLLVMInterface(DialectRegistry ®istry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.