31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
45 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
47 auto idxType = rewriter.getIndexType();
48 auto constant = LLVM::ConstantOp::create(
49 rewriter, loc, typeConverter.convertType(idxType),
50 rewriter.getIntegerAttr(idxType, pos));
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
62 auto idxType = rewriter.getIndexType();
63 auto constant = LLVM::ConstantOp::create(
64 rewriter, loc, typeConverter.convertType(idxType),
65 rewriter.getIntegerAttr(idxType, pos));
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
74 VectorType vectorType,
unsigned &align) {
75 Type convertedVectorTy = typeConverter.convertType(vectorType);
76 if (!convertedVectorTy)
79 llvm::LLVMContext llvmContext;
89 MemRefType memrefType,
unsigned &align) {
90 Type elementTy = typeConverter.convertType(memrefType.getElementType());
96 llvm::LLVMContext llvmContext;
109 VectorType vectorType,
110 MemRefType memrefType,
unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
127 if (!memRefType.isLastDimUnitStride())
137 MemRefType memRefType,
Value llvmMemref,
Value base,
140 "unsupported memref type");
141 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
145 vectorType.getScalableDims()[0]);
146 return LLVM::GEPOp::create(
147 rewriter, loc, ptrsType,
148 typeConverter.convertType(memRefType.getElementType()), base,
index);
155 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
160 return cast<Value>(foldResult);
166using VectorScaleOpConversion =
170class VectorBitCastOpConversion
173 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
176 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter)
const override {
179 VectorType resultTy = bitCastOp.getResultVectorType();
180 if (resultTy.getRank() > 1)
182 Type newResultTy = typeConverter->convertType(resultTy);
183 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
184 adaptor.getOperands()[0]);
192static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193 vector::LoadOpAdaptor adaptor,
194 VectorType vectorTy,
Value ptr,
unsigned align,
195 ConversionPatternRewriter &rewriter) {
196 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy,
ptr, align,
198 loadOp.getNontemporal());
201static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202 vector::MaskedLoadOpAdaptor adaptor,
203 VectorType vectorTy,
Value ptr,
unsigned align,
204 ConversionPatternRewriter &rewriter) {
205 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206 loadOp, vectorTy,
ptr, adaptor.getMask(), adaptor.getPassThru(), align);
209static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210 vector::StoreOpAdaptor adaptor,
211 VectorType vectorTy,
Value ptr,
unsigned align,
212 ConversionPatternRewriter &rewriter) {
213 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
215 storeOp.getNontemporal());
218static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219 vector::MaskedStoreOpAdaptor adaptor,
220 VectorType vectorTy,
Value ptr,
unsigned align,
221 ConversionPatternRewriter &rewriter) {
222 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223 storeOp, adaptor.getValueToStore(),
ptr, adaptor.getMask(), align);
228template <
class LoadOrStoreOp>
231 explicit VectorLoadStoreConversion(
const LLVMTypeConverter &typeConv,
233 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
234 useVectorAlignment(useVectorAlign) {}
235 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter)
const override {
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
254 memRefTy, align, useVectorAlignment)))
255 return rewriter.notifyMatchFailure(loadOrStoreOp,
256 "could not resolve alignment");
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
273 const bool useVectorAlignment;
277class VectorGatherOpConversion
280 explicit VectorGatherOpConversion(
const LLVMTypeConverter &typeConv,
282 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
283 useVectorAlignment(useVectorAlign) {}
284 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 Location loc = gather->getLoc();
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType &&
"The base should be bufferized");
295 return rewriter.notifyMatchFailure(gather,
"memref type not supported");
297 VectorType vType = gather.getVectorType();
298 if (vType.getRank() > 1) {
299 return rewriter.notifyMatchFailure(
300 gather,
"only 1-D vectors can be lowered to LLVM");
305 unsigned align = gather.getAlignment().value_or(0);
308 memRefType, align, useVectorAlignment)))
309 return rewriter.notifyMatchFailure(gather,
"could not resolve alignment");
313 adaptor.getBase(), adaptor.getOffsets());
314 Value base = adaptor.getBase();
316 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
317 base, ptr, adaptor.getIndices(), vType);
320 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
321 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
322 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
331 const bool useVectorAlignment;
335class VectorScatterOpConversion
338 explicit VectorScatterOpConversion(
const LLVMTypeConverter &typeConv,
340 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
341 useVectorAlignment(useVectorAlign) {}
343 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
346 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
347 ConversionPatternRewriter &rewriter)
const override {
348 auto loc = scatter->getLoc();
349 auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
350 assert(memRefType &&
"The base should be bufferized");
354 return rewriter.notifyMatchFailure(scatter,
"memref type not supported");
356 VectorType vType = scatter.getVectorType();
357 if (vType.getRank() > 1) {
358 return rewriter.notifyMatchFailure(
359 scatter,
"only 1-D vectors can be lowered to LLVM");
364 unsigned align = scatter.getAlignment().value_or(0);
367 memRefType, align, useVectorAlignment)))
368 return rewriter.notifyMatchFailure(scatter,
369 "could not resolve alignment");
373 adaptor.getBase(), adaptor.getOffsets());
375 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
376 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
379 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
380 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
381 rewriter.getI32IntegerAttr(align));
390 const bool useVectorAlignment;
394class VectorExpandLoadOpConversion
397 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
400 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
401 ConversionPatternRewriter &rewriter)
const override {
402 auto loc = expand->getLoc();
403 MemRefType memRefType = expand.getMemRefType();
406 auto vtype = typeConverter->convertType(expand.getVectorType());
408 adaptor.getBase(), adaptor.getIndices());
413 uint64_t alignment = expand.getAlignment().value_or(1);
415 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
416 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
423class VectorCompressStoreOpConversion
426 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
429 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
430 ConversionPatternRewriter &rewriter)
const override {
431 auto loc = compress->getLoc();
432 MemRefType memRefType = compress.getMemRefType();
436 adaptor.getBase(), adaptor.getIndices());
441 uint64_t alignment = compress.getAlignment().value_or(1);
443 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
444 compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
// Empty tag types naming the neutral (identity) start value of a reduction.
// They carry no data; they are passed as template arguments (e.g. to
// getOrCreateAccumulator) to choose which constant seeds the accumulator,
// such as zero for ADD or the unsigned maximum for MINUI.
struct ReductionNeutralZero {};
struct ReductionNeutralIntOne {};
struct ReductionNeutralFPOne {};
struct ReductionNeutralAllOnes {};
struct ReductionNeutralSIntMin {};
struct ReductionNeutralUIntMin {};
struct ReductionNeutralSIntMax {};
struct ReductionNeutralUIntMax {};
struct ReductionNeutralFPMin {};
struct ReductionNeutralFPMax {};
463 ConversionPatternRewriter &rewriter,
465 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
466 rewriter.getZeroAttr(llvmType));
471 ConversionPatternRewriter &rewriter,
473 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
474 rewriter.getIntegerAttr(llvmType, 1));
479 ConversionPatternRewriter &rewriter,
481 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
482 rewriter.getFloatAttr(llvmType, 1.0));
487 ConversionPatternRewriter &rewriter,
489 return LLVM::ConstantOp::create(
490 rewriter, loc, llvmType,
491 rewriter.getIntegerAttr(
497 ConversionPatternRewriter &rewriter,
499 return LLVM::ConstantOp::create(
500 rewriter, loc, llvmType,
501 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
507 ConversionPatternRewriter &rewriter,
509 return LLVM::ConstantOp::create(
510 rewriter, loc, llvmType,
511 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
517 ConversionPatternRewriter &rewriter,
519 return LLVM::ConstantOp::create(
520 rewriter, loc, llvmType,
521 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
527 ConversionPatternRewriter &rewriter,
529 return LLVM::ConstantOp::create(
530 rewriter, loc, llvmType,
531 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
537 ConversionPatternRewriter &rewriter,
539 auto floatType = cast<FloatType>(llvmType);
540 return LLVM::ConstantOp::create(
541 rewriter, loc, llvmType,
542 rewriter.getFloatAttr(
543 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
549 ConversionPatternRewriter &rewriter,
551 auto floatType = cast<FloatType>(llvmType);
552 return LLVM::ConstantOp::create(
553 rewriter, loc, llvmType,
554 rewriter.getFloatAttr(
555 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
561template <
class ReductionNeutral>
562static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
575static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
577 VectorType vType = cast<VectorType>(llvmType);
578 auto vShape = vType.getShape();
579 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
581 Value baseVecLength = LLVM::ConstantOp::create(
582 rewriter, loc, rewriter.getI32Type(),
583 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
585 if (!vType.getScalableDims()[0])
586 return baseVecLength;
589 Value vScale = vector::VectorScaleOp::create(rewriter, loc);
591 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
592 Value scalableVecLength =
593 arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
594 return scalableVecLength;
601template <
class LLVMRedIntrinOp,
class ScalarOp>
602static Value createIntegerReductionArithmeticOpLowering(
603 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
607 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
610 result = ScalarOp::create(rewriter, loc, accumulator,
result);
618template <
class LLVMRedIntrinOp>
619static Value createIntegerReductionComparisonOpLowering(
620 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
621 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
623 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
626 LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator,
result);
627 result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator,
result);
// Maps an LLVM vector-reduction intrinsic op to the scalar op used to combine
// the reduction result with an accumulator (via the nested `Type` alias).
// The primary template is intentionally left without a definition, so only
// intrinsics that have an explicit specialization can be used.
template <class Source> struct VectorToScalarMapper;
636struct VectorToScalarMapper<
LLVM::vector_reduce_fmaximum> {
637 using Type = LLVM::MaximumOp;
640struct VectorToScalarMapper<
LLVM::vector_reduce_fminimum> {
641 using Type = LLVM::MinimumOp;
644struct VectorToScalarMapper<
LLVM::vector_reduce_fmax> {
645 using Type = LLVM::MaxNumOp;
648struct VectorToScalarMapper<
LLVM::vector_reduce_fmin> {
649 using Type = LLVM::MinNumOp;
653template <
class LLVMRedIntrinOp>
654static Value createFPReductionComparisonOpLowering(
655 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
656 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
658 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
661 result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
662 rewriter, loc,
result, accumulator);
// Empty tag types selecting the constant that masked-off lanes are replaced
// with before a masked fmaximum/fminimum reduction is lowered to the regular
// (unmasked) LLVM reduction intrinsic.
struct MaskNeutralFMaximum {};
struct MaskNeutralFMinimum {};
674getMaskNeutralValue(MaskNeutralFMaximum,
675 const llvm::fltSemantics &floatSemantics) {
676 return llvm::APFloat::getSmallest(floatSemantics,
true);
680getMaskNeutralValue(MaskNeutralFMinimum,
681 const llvm::fltSemantics &floatSemantics) {
682 return llvm::APFloat::getLargest(floatSemantics,
false);
686template <
typename MaskNeutral>
687static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
690 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
691 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
693 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
700template <
class LLVMRedIntrinOp,
class MaskNeutral>
702lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
705 Value mask, LLVM::FastmathFlagsAttr fmf) {
706 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
707 rewriter, loc, llvmType, vectorOperand.
getType());
708 const Value selectedVectorByMask = LLVM::SelectOp::create(
709 rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
710 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
711 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
714template <
class LLVMRedIntrinOp,
class ReductionNeutral>
716lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc,
718 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
719 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
720 llvmType, accumulator);
721 return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
722 accumulator, vectorOperand,
729template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
731lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
734 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
735 llvmType, accumulator);
736 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
737 accumulator, vectorOperand);
740template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
741static Value lowerPredicatedReductionWithStartValue(
742 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
744 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
745 llvmType, accumulator);
747 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
748 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
749 accumulator, vectorOperand,
753template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
754 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
755static Value lowerPredicatedReductionWithStartValue(
756 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
759 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
760 IntReductionNeutral>(
761 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
764 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
766 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
770class VectorReductionOpConversion
773 explicit VectorReductionOpConversion(
const LLVMTypeConverter &typeConv,
774 bool reassociateFPRed)
775 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
776 reassociateFPReductions(reassociateFPRed) {}
779 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
780 ConversionPatternRewriter &rewriter)
const override {
781 auto kind = reductionOp.getKind();
782 Type eltType = reductionOp.getDest().getType();
783 Type llvmType = typeConverter->convertType(eltType);
784 Value operand = adaptor.getVector();
785 Value acc = adaptor.getAcc();
786 Location loc = reductionOp.getLoc();
792 case vector::CombiningKind::ADD:
794 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
796 rewriter, loc, llvmType, operand, acc);
798 case vector::CombiningKind::MUL:
800 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
802 rewriter, loc, llvmType, operand, acc);
804 case vector::CombiningKind::MINUI:
805 result = createIntegerReductionComparisonOpLowering<
806 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
807 LLVM::ICmpPredicate::ule);
809 case vector::CombiningKind::MINSI:
810 result = createIntegerReductionComparisonOpLowering<
811 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
812 LLVM::ICmpPredicate::sle);
814 case vector::CombiningKind::MAXUI:
815 result = createIntegerReductionComparisonOpLowering<
816 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
817 LLVM::ICmpPredicate::uge);
819 case vector::CombiningKind::MAXSI:
820 result = createIntegerReductionComparisonOpLowering<
821 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
822 LLVM::ICmpPredicate::sge);
824 case vector::CombiningKind::AND:
826 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
828 rewriter, loc, llvmType, operand, acc);
830 case vector::CombiningKind::OR:
832 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
834 rewriter, loc, llvmType, operand, acc);
836 case vector::CombiningKind::XOR:
838 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
840 rewriter, loc, llvmType, operand, acc);
845 rewriter.replaceOp(reductionOp,
result);
850 if (!isa<FloatType>(eltType))
853 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
854 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
855 reductionOp.getContext(),
857 fmf = LLVM::FastmathFlagsAttr::get(
858 reductionOp.getContext(),
859 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
860 : LLVM::FastmathFlags::none));
864 if (kind == vector::CombiningKind::ADD) {
865 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
866 ReductionNeutralZero>(
867 rewriter, loc, llvmType, operand, acc, fmf);
868 }
else if (kind == vector::CombiningKind::MUL) {
869 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
870 ReductionNeutralFPOne>(
871 rewriter, loc, llvmType, operand, acc, fmf);
872 }
else if (kind == vector::CombiningKind::MINIMUMF) {
874 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
875 rewriter, loc, llvmType, operand, acc, fmf);
876 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
878 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
879 rewriter, loc, llvmType, operand, acc, fmf);
880 }
else if (kind == vector::CombiningKind::MINNUMF) {
881 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
882 rewriter, loc, llvmType, operand, acc, fmf);
883 }
else if (kind == vector::CombiningKind::MAXNUMF) {
884 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
885 rewriter, loc, llvmType, operand, acc, fmf);
890 rewriter.replaceOp(reductionOp,
result);
895 const bool reassociateFPReductions;
906template <
class MaskedOp>
907class VectorMaskOpConversionBase
910 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
913 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
914 ConversionPatternRewriter &rewriter)
const final {
916 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
919 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
923 virtual LogicalResult
924 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
925 vector::MaskableOpInterface maskableOp,
926 ConversionPatternRewriter &rewriter)
const = 0;
929class MaskedReductionOpConversion
930 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
933 using VectorMaskOpConversionBase<
934 vector::ReductionOp>::VectorMaskOpConversionBase;
936 LogicalResult matchAndRewriteMaskableOp(
937 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
938 ConversionPatternRewriter &rewriter)
const override {
939 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
940 auto kind = reductionOp.getKind();
941 Type eltType = reductionOp.getDest().getType();
942 Type llvmType = typeConverter->convertType(eltType);
943 Value operand = reductionOp.getVector();
944 Value acc = reductionOp.getAcc();
945 Location loc = reductionOp.getLoc();
947 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
948 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
949 reductionOp.getContext(),
954 case vector::CombiningKind::ADD:
955 result = lowerPredicatedReductionWithStartValue<
956 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
957 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
960 case vector::CombiningKind::MUL:
961 result = lowerPredicatedReductionWithStartValue<
962 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
963 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
966 case vector::CombiningKind::MINUI:
967 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
968 ReductionNeutralUIntMax>(
969 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
971 case vector::CombiningKind::MINSI:
972 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
973 ReductionNeutralSIntMax>(
974 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
976 case vector::CombiningKind::MAXUI:
977 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
978 ReductionNeutralUIntMin>(
979 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
981 case vector::CombiningKind::MAXSI:
982 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
983 ReductionNeutralSIntMin>(
984 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
986 case vector::CombiningKind::AND:
987 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
988 ReductionNeutralAllOnes>(
989 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
991 case vector::CombiningKind::OR:
992 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
993 ReductionNeutralZero>(
994 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
996 case vector::CombiningKind::XOR:
997 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
998 ReductionNeutralZero>(
999 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1001 case vector::CombiningKind::MINNUMF:
1002 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1003 ReductionNeutralFPMax>(
1004 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1006 case vector::CombiningKind::MAXNUMF:
1007 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1008 ReductionNeutralFPMin>(
1009 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1011 case CombiningKind::MAXIMUMF:
1012 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1013 MaskNeutralFMaximum>(
1014 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1016 case CombiningKind::MINIMUMF:
1017 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1018 MaskNeutralFMinimum>(
1019 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1024 rewriter.replaceOp(maskOp,
result);
1029class VectorShuffleOpConversion
1032 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1035 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1036 ConversionPatternRewriter &rewriter)
const override {
1037 auto loc = shuffleOp->getLoc();
1038 auto v1Type = shuffleOp.getV1VectorType();
1039 auto v2Type = shuffleOp.getV2VectorType();
1040 auto vectorType = shuffleOp.getResultVectorType();
1041 Type llvmType = typeConverter->convertType(vectorType);
1042 ArrayRef<int64_t> mask = shuffleOp.getMask();
1049 int64_t rank = vectorType.getRank();
1051 bool wellFormed0DCase =
1052 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1053 bool wellFormedNDCase =
1054 v1Type.getRank() == rank && v2Type.getRank() == rank;
1055 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1060 if (rank <= 1 && v1Type == v2Type) {
1061 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1062 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1063 llvm::to_vector_of<int32_t>(mask));
1064 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1069 int64_t v1Dim = v1Type.getDimSize(0);
1071 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1072 eltType = arrayType.getElementType();
1074 eltType = cast<VectorType>(llvmType).getElementType();
1075 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1077 for (int64_t extPos : mask) {
1078 Value value = adaptor.getV1();
1079 if (extPos >= v1Dim) {
1081 value = adaptor.getV2();
1083 Value extract =
extractOne(rewriter, *getTypeConverter(), loc, value,
1084 eltType, rank, extPos);
1085 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1086 llvmType, rank, insPos++);
1088 rewriter.replaceOp(shuffleOp, insert);
1093class VectorExtractOpConversion
1096 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1099 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1100 ConversionPatternRewriter &rewriter)
const override {
1101 auto loc = extractOp->getLoc();
1102 auto resultType = extractOp.getResult().getType();
1103 auto llvmResultType = typeConverter->convertType(resultType);
1105 if (!llvmResultType)
1109 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1123 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1127 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1128 extractOp.getSourceVectorType().getRank();
1132 if (extractOp.getSourceVectorType().getRank() == 0) {
1133 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1134 positionVec.push_back(rewriter.getZeroAttr(idxType));
1137 Value extracted = adaptor.getSource();
1138 if (extractsAggregate) {
1139 ArrayRef<OpFoldResult> position(positionVec);
1140 if (extractsScalar) {
1144 position = position.drop_back();
1147 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1150 extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1154 if (extractsScalar) {
1155 extracted = LLVM::ExtractElementOp::create(
1156 rewriter, loc, extracted,
1160 rewriter.replaceOp(extractOp, extracted);
1181 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1184 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1185 ConversionPatternRewriter &rewriter)
const override {
1186 VectorType vType = fmaOp.getVectorType();
1187 if (vType.getRank() > 1)
1190 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1191 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1196class VectorInsertOpConversion
1199 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1202 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1203 ConversionPatternRewriter &rewriter)
const override {
1204 auto loc = insertOp->getLoc();
1205 auto destVectorType = insertOp.getDestVectorType();
1206 auto llvmResultType = typeConverter->convertType(destVectorType);
1208 if (!llvmResultType)
1212 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1234 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1236 bool insertIntoInnermostDim =
1237 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
1239 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1240 positionVec.begin(),
1241 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1242 OpFoldResult positionOfScalarWithin1DVector;
1243 if (destVectorType.getRank() == 0) {
1246 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1247 positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1248 }
else if (insertIntoInnermostDim) {
1249 positionOfScalarWithin1DVector = positionVec.back();
1255 Value sourceAggregate = adaptor.getValueToStore();
1256 if (insertIntoInnermostDim) {
1259 if (isNestedAggregate) {
1262 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1263 llvm::IsaPred<Attribute>)) {
1267 sourceAggregate = LLVM::ExtractValueOp::create(
1268 rewriter, loc, adaptor.getDest(),
1273 sourceAggregate = adaptor.getDest();
1276 sourceAggregate = LLVM::InsertElementOp::create(
1277 rewriter, loc, sourceAggregate.
getType(), sourceAggregate,
1278 adaptor.getValueToStore(),
1282 Value
result = sourceAggregate;
1283 if (isNestedAggregate) {
1284 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1285 llvm::IsaPred<Attribute>)) {
1289 result = LLVM::InsertValueOp::create(
1290 rewriter, loc, adaptor.getDest(), sourceAggregate,
1294 rewriter.replaceOp(insertOp,
result);
1300struct VectorScalableInsertOpLowering
1302 using ConvertOpToLLVMPattern<
1303 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1306 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1307 ConversionPatternRewriter &rewriter)
const override {
1308 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1309 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1315struct VectorScalableExtractOpLowering
1317 using ConvertOpToLLVMPattern<
1318 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1321 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1322 ConversionPatternRewriter &rewriter)
const override {
1323 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1324 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1325 adaptor.getSource(), adaptor.getPos());
1358 setHasBoundedRewriteRecursion();
1361 LogicalResult matchAndRewrite(FMAOp op,
1362 PatternRewriter &rewriter)
const override {
1363 auto vType = op.getVectorType();
1364 if (vType.getRank() < 2)
1367 auto loc = op.getLoc();
1368 auto elemType = vType.getElementType();
1369 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1371 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1372 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1373 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1374 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1375 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1376 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1377 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1386static std::optional<SmallVector<int64_t, 4>>
1387computeContiguousStrides(MemRefType memRefType) {
1390 if (
failed(memRefType.getStridesAndOffset(strides, offset)))
1391 return std::nullopt;
1392 if (!strides.empty() && strides.back() != 1)
1393 return std::nullopt;
1395 if (memRefType.getLayout().isIdentity())
1402 auto sizes = memRefType.getShape();
1404 if (ShapedType::isDynamic(sizes[
index + 1]) ||
1405 ShapedType::isDynamic(strides[
index]) ||
1406 ShapedType::isDynamic(strides[
index + 1]))
1407 return std::nullopt;
1409 return std::nullopt;
1414class VectorTypeCastOpConversion
1417 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1420 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1421 ConversionPatternRewriter &rewriter)
const override {
1422 auto loc = castOp->getLoc();
1423 MemRefType sourceMemRefType =
1424 cast<MemRefType>(castOp.getOperand().getType());
1425 MemRefType targetMemRefType = castOp.getType();
1428 if (!sourceMemRefType.hasStaticShape() ||
1429 !targetMemRefType.hasStaticShape())
1432 auto llvmSourceDescriptorTy =
1433 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1434 if (!llvmSourceDescriptorTy)
1436 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1438 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1439 typeConverter->convertType(targetMemRefType));
1440 if (!llvmTargetDescriptorTy)
1444 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1447 auto targetStrides = computeContiguousStrides(targetMemRefType);
1451 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1454 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1457 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1459 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1460 desc.setAllocatedPtr(rewriter, loc, allocated);
1463 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1464 desc.setAlignedPtr(rewriter, loc, ptr);
1466 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1467 auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1468 desc.setOffset(rewriter, loc, zero);
1471 for (
const auto &indexedSize :
1472 llvm::enumerate(targetMemRefType.getShape())) {
1473 int64_t index = indexedSize.index();
1475 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1476 auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1477 desc.setSize(rewriter, loc, index, size);
1478 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1479 (*targetStrides)[index]);
1481 LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1482 desc.setStride(rewriter, loc, index, stride);
1485 rewriter.replaceOp(castOp, {desc});
1492class VectorCreateMaskOpConversion
1493 :
public OpConversionPattern<vector::CreateMaskOp> {
1495 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1496 bool enableIndexOpt)
1497 : OpConversionPattern<vector::CreateMaskOp>(context),
1498 force32BitVectorIndices(enableIndexOpt) {}
1501 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1502 ConversionPatternRewriter &rewriter)
const override {
1503 auto dstType = op.getType();
1504 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1506 IntegerType idxType =
1507 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1508 auto loc = op->getLoc();
1509 Value
indices = LLVM::StepVectorOp::create(
1511 LLVM::getVectorType(idxType, dstType.getShape()[0],
1514 adaptor.getOperands()[0]);
1515 Value bounds = BroadcastOp::create(rewriter, loc,
indices.getType(), bound);
1516 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1518 rewriter.replaceOp(op, comp);
1523 const bool force32BitVectorIndices;
1527 SymbolTableCollection *symbolTables =
nullptr;
1530 explicit VectorPrintOpConversion(
1531 const LLVMTypeConverter &typeConverter,
1532 SymbolTableCollection *symbolTables =
nullptr)
1533 : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1534 symbolTables(symbolTables) {}
1550 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1551 ConversionPatternRewriter &rewriter)
const override {
1552 auto parent =
printOp->getParentOfType<ModuleOp>();
1558 if (
auto value = adaptor.getSource()) {
1560 if (isa<VectorType>(printType)) {
1564 if (
failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1568 auto punct =
printOp.getPunctuation();
1569 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1571 LLVM::createPrintStrCall(rewriter, loc, parent,
"vector_print_str",
1572 *stringLiteral, *getTypeConverter(),
1574 if (createResult.failed())
1577 }
else if (punct != PrintPunctuation::NoPunctuation) {
1578 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1580 case PrintPunctuation::Close:
1581 return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
1583 case PrintPunctuation::Open:
1584 return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
1586 case PrintPunctuation::Comma:
1587 return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
1589 case PrintPunctuation::NewLine:
1590 return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
1593 llvm_unreachable(
"unexpected punctuation");
1598 emitCall(rewriter,
printOp->getLoc(), op.value());
1606 enum class PrintConversion {
1615 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1616 ModuleOp parent, Location loc, Type printType,
1617 Value value)
const {
1618 if (typeConverter->convertType(printType) ==
nullptr)
1622 PrintConversion conversion = PrintConversion::None;
1623 FailureOr<Operation *> printer;
1625 printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
1627 printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
1629 conversion = PrintConversion::Bitcast16;
1630 printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
1632 conversion = PrintConversion::Bitcast16;
1633 printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
1635 printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
1636 }
else if (
auto intTy = dyn_cast<IntegerType>(printType)) {
1640 unsigned width = intTy.getWidth();
1641 if (intTy.isUnsigned()) {
1644 conversion = PrintConversion::ZeroExt64;
1646 LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
1651 assert(intTy.isSignless() || intTy.isSigned());
1656 conversion = PrintConversion::ZeroExt64;
1657 else if (width < 64)
1658 conversion = PrintConversion::SignExt64;
1660 LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
1665 }
else if (
auto floatTy = dyn_cast<FloatType>(printType)) {
1668 llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1669 Value semValue = LLVM::ConstantOp::create(
1670 rewriter, loc, rewriter.getI32Type(),
1671 rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1673 LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1675 LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
1676 emitCall(rewriter, loc, printer.value(),
1685 switch (conversion) {
1686 case PrintConversion::ZeroExt64:
1687 value = arith::ExtUIOp::create(
1688 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1690 case PrintConversion::SignExt64:
1691 value = arith::ExtSIOp::create(
1692 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1694 case PrintConversion::Bitcast16:
1695 value = LLVM::BitcastOp::create(
1696 rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
1698 case PrintConversion::None:
1701 emitCall(rewriter, loc, printer.value(), value);
1706 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1708 LLVM::CallOp::create(rewriter, loc,
TypeRange(), SymbolRefAttr::get(ref),
1716struct VectorBroadcastScalarToLowRankLowering
1718 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1721 matchAndRewrite(vector::BroadcastOp
broadcast, OpAdaptor adaptor,
1722 ConversionPatternRewriter &rewriter)
const override {
1723 if (isa<VectorType>(
broadcast.getSourceType()))
1724 return rewriter.notifyMatchFailure(
1725 broadcast,
"broadcast from vector type not handled");
1728 if (resultType.getRank() > 1)
1729 return rewriter.notifyMatchFailure(
broadcast,
1730 "broadcast to 2+-d handled elsewhere");
1736 auto zero = LLVM::ConstantOp::create(
1738 typeConverter->convertType(rewriter.getIntegerType(32)),
1739 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1742 if (resultType.getRank() == 0) {
1743 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1744 broadcast, vectorType, poison, adaptor.getSource(), zero);
1749 LLVM::InsertElementOp::create(rewriter,
broadcast.
getLoc(), vectorType,
1750 poison, adaptor.getSource(), zero);
1754 SmallVector<int32_t> zeroValues(width, 0);
1757 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1768struct VectorBroadcastScalarToNdLowering
1770 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1773 matchAndRewrite(BroadcastOp
broadcast, OpAdaptor adaptor,
1774 ConversionPatternRewriter &rewriter)
const override {
1775 if (isa<VectorType>(
broadcast.getSourceType()))
1776 return rewriter.notifyMatchFailure(
1777 broadcast,
"broadcast from vector type not handled");
1780 if (resultType.getRank() <= 1)
1781 return rewriter.notifyMatchFailure(
1782 broadcast,
"broadcast to 1-d or 0-d handled elsewhere");
1786 auto vectorTypeInfo =
1788 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1789 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1790 if (!llvmNDVectorTy || !llvm1DVectorTy)
1794 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1798 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1799 auto zero = LLVM::ConstantOp::create(
1800 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1801 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1802 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1803 vdesc, adaptor.getSource(), zero);
1806 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1807 SmallVector<int32_t> zeroValues(width, 0);
1808 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1812 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1813 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1822struct VectorInterleaveOpLowering
1827 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1828 ConversionPatternRewriter &rewriter)
const override {
1829 VectorType resultType = interleaveOp.getResultVectorType();
1831 if (resultType.getRank() != 1)
1832 return rewriter.notifyMatchFailure(interleaveOp,
1833 "InterleaveOp not rank 1");
1835 if (resultType.isScalable()) {
1836 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1837 interleaveOp, typeConverter->convertType(resultType),
1838 adaptor.getLhs(), adaptor.getRhs());
1845 int64_t resultVectorSize = resultType.getNumElements();
1846 SmallVector<int32_t> interleaveShuffleMask;
1847 interleaveShuffleMask.reserve(resultVectorSize);
1848 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1849 interleaveShuffleMask.push_back(i);
1850 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1852 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1853 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1854 interleaveShuffleMask);
1861struct VectorDeinterleaveOpLowering
1866 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1867 ConversionPatternRewriter &rewriter)
const override {
1868 VectorType resultType = deinterleaveOp.getResultVectorType();
1869 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1870 auto loc = deinterleaveOp.getLoc();
1874 if (resultType.getRank() != 1)
1875 return rewriter.notifyMatchFailure(deinterleaveOp,
1876 "DeinterleaveOp not rank 1");
1878 if (resultType.isScalable()) {
1879 const auto *llvmTypeConverter = this->getTypeConverter();
1880 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1881 auto packedOpResults =
1882 llvmTypeConverter->packOperationResults(deinterleaveResults);
1883 auto intrinsic = LLVM::vector_deinterleave2::create(
1884 rewriter, loc, packedOpResults, adaptor.getSource());
1886 auto evenResult = LLVM::ExtractValueOp::create(
1887 rewriter, loc, intrinsic->getResult(0), 0);
1888 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1889 intrinsic->getResult(0), 1);
1891 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenResult, oddResult});
1898 int64_t resultVectorSize = resultType.getNumElements();
1899 SmallVector<int32_t> evenShuffleMask;
1900 SmallVector<int32_t> oddShuffleMask;
1902 evenShuffleMask.reserve(resultVectorSize);
1903 oddShuffleMask.reserve(resultVectorSize);
1905 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1907 evenShuffleMask.push_back(i);
1909 oddShuffleMask.push_back(i);
1912 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1913 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1914 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1915 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1916 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1918 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenShuffle, oddShuffle});
1924struct VectorFromElementsLowering
1929 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1930 ConversionPatternRewriter &rewriter)
const override {
1931 Location loc = fromElementsOp.getLoc();
1932 VectorType vectorType = fromElementsOp.getType();
1936 if (vectorType.getRank() > 1)
1937 return rewriter.notifyMatchFailure(fromElementsOp,
1938 "rank > 1 vectors are not supported");
1939 Type llvmType = typeConverter->convertType(vectorType);
1940 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1941 Value
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1942 for (
auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1944 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1945 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType,
result,
1948 rewriter.replaceOp(fromElementsOp,
result);
1954struct VectorToElementsLowering
1959 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1960 ConversionPatternRewriter &rewriter)
const override {
1961 Location loc = toElementsOp.getLoc();
1962 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1963 Value source = adaptor.getSource();
1965 SmallVector<Value> results(toElementsOp->getNumResults());
1966 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1968 if (element.use_empty())
1971 auto constIdx = LLVM::ConstantOp::create(
1972 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1973 auto llvmType = typeConverter->convertType(element.getType());
1975 Value
result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1980 rewriter.replaceOp(toElementsOp, results);
1986struct VectorScalableStepOpLowering
1991 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1992 ConversionPatternRewriter &rewriter)
const override {
1993 auto resultType = cast<VectorType>(stepOp.getType());
1994 if (!resultType.isScalable()) {
1997 Type llvmType = typeConverter->convertType(stepOp.getType());
1998 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
2013class ContractionOpToMatmulOpLowering
2016 using MaskableOpRewritePattern::MaskableOpRewritePattern;
2018 ContractionOpToMatmulOpLowering(MLIRContext *context,
2019 PatternBenefit benefit = 100)
2020 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2023 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
2024 PatternRewriter &rewriter)
const override;
2044FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2045 vector::ContractionOp op, MaskingOpInterface maskOp,
2051 auto iteratorTypes = op.getIteratorTypes().getValue();
2057 Type opResType = op.getType();
2058 VectorType vecType = dyn_cast<VectorType>(opResType);
2059 if (vecType && vecType.isScalable()) {
2064 Type elementType = op.getLhsType().getElementType();
2068 Type dstElementType = vecType ? vecType.getElementType() : opResType;
2069 if (elementType != dstElementType)
2074 MLIRContext *ctx = op.getContext();
2075 Location loc = op.getLoc();
2079 Value
lhs = op.getLhs();
2080 auto lhsMap = op.getIndexingMapsArray()[0];
2082 lhs = vector::TransposeOp::create(rew, loc,
lhs, ArrayRef<int64_t>{1, 0});
2087 Value
rhs = op.getRhs();
2088 auto rhsMap = op.getIndexingMapsArray()[1];
2090 rhs = vector::TransposeOp::create(rew, loc,
rhs, ArrayRef<int64_t>{1, 0});
2095 VectorType lhsType = cast<VectorType>(
lhs.getType());
2096 VectorType rhsType = cast<VectorType>(
rhs.getType());
2097 int64_t lhsRows = lhsType.getDimSize(0);
2098 int64_t lhsColumns = lhsType.getDimSize(1);
2099 int64_t rhsColumns = rhsType.getDimSize(1);
2101 Type flattenedLHSType =
2102 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2103 lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType,
lhs);
2105 Type flattenedRHSType =
2106 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2107 rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType,
rhs);
2109 Value
mul = LLVM::MatrixMultiplyOp::create(
2111 VectorType::get(lhsRows * rhsColumns,
2112 cast<VectorType>(
lhs.getType()).getElementType()),
2113 lhs,
rhs, lhsRows, lhsColumns, rhsColumns);
2115 mul = vector::ShapeCastOp::create(
2117 VectorType::get({lhsRows, rhsColumns},
2122 auto accMap = op.getIndexingMapsArray()[2];
2124 mul = vector::TransposeOp::create(rew, loc,
mul, ArrayRef<int64_t>{1, 0});
2126 llvm_unreachable(
"invalid contraction semantics");
2128 Value res = isa<IntegerType>(elementType)
2129 ?
static_cast<Value
>(
2130 arith::AddIOp::create(rew, loc, op.getAcc(),
mul))
2131 : static_cast<Value>(
2132 arith::AddFOp::create(rew, loc, op.getAcc(),
mul));
2150class TransposeOpToMatrixTransposeOpLowering
2151 :
public OpRewritePattern<vector::TransposeOp> {
2155 LogicalResult matchAndRewrite(vector::TransposeOp op,
2156 PatternRewriter &rewriter)
const override {
2157 auto loc = op.getLoc();
2159 Value input = op.getVector();
2160 VectorType inputType = op.getSourceVectorType();
2161 VectorType resType = op.getResultVectorType();
2163 if (inputType.isScalable())
2165 op,
"This lowering does not support scalable vectors");
2168 ArrayRef<int64_t> transp = op.getPermutation();
2170 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2174 Type flattenedType =
2175 VectorType::get(resType.getNumElements(), resType.getElementType());
2177 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2180 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2181 matrix, rows, columns);
2191 patterns.
add<VectorFMAOpNDRewritePattern>(patterns.
getContext());
2196 patterns.
add<ContractionOpToMatmulOpLowering>(patterns.
getContext(), benefit);
2201 patterns.
add<TransposeOpToMatrixTransposeOpLowering>(patterns.
getContext(),
2208 bool reassociateFPReductions,
bool force32BitVectorIndices,
2209 bool useVectorAlignment) {
2212 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2213 patterns.
add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2214 patterns.
add<VectorLoadStoreConversion<vector::LoadOp>,
2215 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2216 VectorLoadStoreConversion<vector::StoreOp>,
2217 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2218 VectorGatherOpConversion, VectorScatterOpConversion>(
2219 converter, useVectorAlignment);
2220 patterns.
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2221 VectorExtractOpConversion, VectorFMAOp1DConversion,
2222 VectorInsertOpConversion, VectorPrintOpConversion,
2223 VectorTypeCastOpConversion, VectorScaleOpConversion,
2224 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2225 VectorBroadcastScalarToLowRankLowering,
2226 VectorBroadcastScalarToNdLowering,
2227 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2228 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2229 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2230 VectorToElementsLowering, VectorScalableStepOpLowering>(
2235struct VectorToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
2236 VectorToLLVMDialectInterface(
Dialect *dialect)
2237 : ConvertToLLVMPatternInterface(dialect) {}
2239 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2240 void loadDependentDialects(MLIRContext *context)
const final {
2241 context->loadDialect<LLVM::LLVMDialect>();
2246 void populateConvertToLLVMConversionPatterns(
2247 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2248 RewritePatternSet &patterns)
const final {
2257 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Conversion from types to the LLVM IR dialect.
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.