31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
45 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
47 auto idxType = rewriter.getIndexType();
48 auto constant = LLVM::ConstantOp::create(
49 rewriter, loc, typeConverter.convertType(idxType),
50 rewriter.getIntegerAttr(idxType, pos));
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
62 auto idxType = rewriter.getIndexType();
63 auto constant = LLVM::ConstantOp::create(
64 rewriter, loc, typeConverter.convertType(idxType),
65 rewriter.getIntegerAttr(idxType, pos));
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
74 VectorType vectorType,
unsigned &align) {
75 Type convertedVectorTy = typeConverter.convertType(vectorType);
76 if (!convertedVectorTy)
79 llvm::LLVMContext llvmContext;
89 MemRefType memrefType,
unsigned &align) {
90 Type elementTy = typeConverter.convertType(memrefType.getElementType());
96 llvm::LLVMContext llvmContext;
109 VectorType vectorType,
110 MemRefType memrefType,
unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
127 if (!memRefType.isLastDimUnitStride())
137 MemRefType memRefType,
Value llvmMemref,
Value base,
140 "unsupported memref type");
141 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
145 vectorType.getScalableDims()[0]);
146 return LLVM::GEPOp::create(
147 rewriter, loc, ptrsType,
148 typeConverter.convertType(memRefType.getElementType()), base,
index);
155 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
160 return cast<Value>(foldResult);
166using VectorScaleOpConversion =
170class VectorBitCastOpConversion
173 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
176 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter)
const override {
179 VectorType resultTy = bitCastOp.getResultVectorType();
180 if (resultTy.getRank() > 1)
182 Type newResultTy = typeConverter->convertType(resultTy);
183 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
184 adaptor.getOperands()[0]);
192static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193 vector::LoadOpAdaptor adaptor,
194 VectorType vectorTy,
Value ptr,
unsigned align,
195 ConversionPatternRewriter &rewriter) {
196 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy,
ptr, align,
198 loadOp.getNontemporal());
201static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202 vector::MaskedLoadOpAdaptor adaptor,
203 VectorType vectorTy,
Value ptr,
unsigned align,
204 ConversionPatternRewriter &rewriter) {
205 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206 loadOp, vectorTy,
ptr, adaptor.getMask(), adaptor.getPassThru(), align);
209static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210 vector::StoreOpAdaptor adaptor,
211 VectorType vectorTy,
Value ptr,
unsigned align,
212 ConversionPatternRewriter &rewriter) {
213 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
215 storeOp.getNontemporal());
218static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219 vector::MaskedStoreOpAdaptor adaptor,
220 VectorType vectorTy,
Value ptr,
unsigned align,
221 ConversionPatternRewriter &rewriter) {
222 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223 storeOp, adaptor.getValueToStore(),
ptr, adaptor.getMask(), align);
228template <
class LoadOrStoreOp>
231 explicit VectorLoadStoreConversion(
const LLVMTypeConverter &typeConv,
233 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
234 useVectorAlignment(useVectorAlign) {}
235 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter)
const override {
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
254 memRefTy, align, useVectorAlignment)))
255 return rewriter.notifyMatchFailure(loadOrStoreOp,
256 "could not resolve alignment");
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
273 const bool useVectorAlignment;
277class VectorGatherOpConversion
280 explicit VectorGatherOpConversion(
const LLVMTypeConverter &typeConv,
282 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
283 useVectorAlignment(useVectorAlign) {}
284 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 Location loc = gather->getLoc();
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType &&
"The base should be bufferized");
295 return rewriter.notifyMatchFailure(gather,
"memref type not supported");
297 VectorType vType = gather.getVectorType();
298 if (vType.getRank() > 1) {
299 return rewriter.notifyMatchFailure(
300 gather,
"only 1-D vectors can be lowered to LLVM");
305 unsigned align = gather.getAlignment().value_or(0);
308 memRefType, align, useVectorAlignment)))
309 return rewriter.notifyMatchFailure(gather,
"could not resolve alignment");
313 adaptor.getBase(), adaptor.getOffsets());
314 Value base = adaptor.getBase();
316 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
317 base, ptr, adaptor.getIndices(), vType);
320 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
321 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
322 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
331 const bool useVectorAlignment;
335class VectorScatterOpConversion
338 explicit VectorScatterOpConversion(
const LLVMTypeConverter &typeConv,
340 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
341 useVectorAlignment(useVectorAlign) {}
343 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
346 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
347 ConversionPatternRewriter &rewriter)
const override {
348 auto loc = scatter->getLoc();
349 auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
350 assert(memRefType &&
"The base should be bufferized");
354 return rewriter.notifyMatchFailure(scatter,
"memref type not supported");
356 VectorType vType = scatter.getVectorType();
357 if (vType.getRank() > 1) {
358 return rewriter.notifyMatchFailure(
359 scatter,
"only 1-D vectors can be lowered to LLVM");
364 unsigned align = scatter.getAlignment().value_or(0);
367 memRefType, align, useVectorAlignment)))
368 return rewriter.notifyMatchFailure(scatter,
369 "could not resolve alignment");
373 adaptor.getBase(), adaptor.getOffsets());
375 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
376 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
379 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
380 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
381 rewriter.getI32IntegerAttr(align));
390 const bool useVectorAlignment;
394class VectorExpandLoadOpConversion
397 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
400 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
401 ConversionPatternRewriter &rewriter)
const override {
402 auto loc = expand->getLoc();
403 MemRefType memRefType = expand.getMemRefType();
406 auto vtype = typeConverter->convertType(expand.getVectorType());
408 adaptor.getBase(), adaptor.getIndices());
413 uint64_t alignment = expand.getAlignment().value_or(1);
415 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
416 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
423class VectorCompressStoreOpConversion
426 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
429 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
430 ConversionPatternRewriter &rewriter)
const override {
431 auto loc = compress->getLoc();
432 MemRefType memRefType = compress.getMemRefType();
436 adaptor.getBase(), adaptor.getIndices());
441 uint64_t alignment = compress.getAlignment().value_or(1);
443 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
444 compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
/// Empty tag types naming the neutral (identity) element a reduction starts
/// from when it needs a start value. They carry no state and are passed as
/// template arguments to the reduction lowering helpers, e.g.
/// `lowerReductionWithStartValue` and `lowerPredicatedReductionWithStartValue`.
/// The pairing noted on each tag is the one used by the masked reduction
/// lowering in this file.
class ReductionNeutralZero {};    // ADD, OR, XOR (and floating-point ADD).
class ReductionNeutralIntOne {};  // Integer MUL.
class ReductionNeutralFPOne {};   // Floating-point MUL.
class ReductionNeutralAllOnes {}; // AND.
class ReductionNeutralSIntMin {}; // MAXSI.
class ReductionNeutralUIntMin {}; // MAXUI.
class ReductionNeutralSIntMax {}; // MINSI.
class ReductionNeutralUIntMax {}; // MINUI.
class ReductionNeutralFPMin {};   // MAXNUMF.
class ReductionNeutralFPMax {};   // MINNUMF.
463 ConversionPatternRewriter &rewriter,
465 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
466 rewriter.getZeroAttr(llvmType));
471 ConversionPatternRewriter &rewriter,
473 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
474 rewriter.getIntegerAttr(llvmType, 1));
479 ConversionPatternRewriter &rewriter,
481 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
482 rewriter.getFloatAttr(llvmType, 1.0));
487 ConversionPatternRewriter &rewriter,
489 return LLVM::ConstantOp::create(
490 rewriter, loc, llvmType,
491 rewriter.getIntegerAttr(
497 ConversionPatternRewriter &rewriter,
499 return LLVM::ConstantOp::create(
500 rewriter, loc, llvmType,
501 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
507 ConversionPatternRewriter &rewriter,
509 return LLVM::ConstantOp::create(
510 rewriter, loc, llvmType,
511 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
517 ConversionPatternRewriter &rewriter,
519 return LLVM::ConstantOp::create(
520 rewriter, loc, llvmType,
521 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
527 ConversionPatternRewriter &rewriter,
529 return LLVM::ConstantOp::create(
530 rewriter, loc, llvmType,
531 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
537 ConversionPatternRewriter &rewriter,
539 auto floatType = cast<FloatType>(llvmType);
540 return LLVM::ConstantOp::create(
541 rewriter, loc, llvmType,
542 rewriter.getFloatAttr(
543 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
549 ConversionPatternRewriter &rewriter,
551 auto floatType = cast<FloatType>(llvmType);
552 return LLVM::ConstantOp::create(
553 rewriter, loc, llvmType,
554 rewriter.getFloatAttr(
555 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
561template <
class ReductionNeutral>
562static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
575static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
577 VectorType vType = cast<VectorType>(llvmType);
578 auto vShape = vType.getShape();
579 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
581 Value baseVecLength = LLVM::ConstantOp::create(
582 rewriter, loc, rewriter.getI32Type(),
583 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
585 if (!vType.getScalableDims()[0])
586 return baseVecLength;
589 Value vScale = vector::VectorScaleOp::create(rewriter, loc);
591 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
592 Value scalableVecLength =
593 arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
594 return scalableVecLength;
601template <
class LLVMRedIntrinOp,
class ScalarOp>
602static Value createIntegerReductionArithmeticOpLowering(
603 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
607 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
610 result = ScalarOp::create(rewriter, loc, accumulator,
result);
618template <
class LLVMRedIntrinOp>
619static Value createIntegerReductionComparisonOpLowering(
620 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
621 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
623 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
626 LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator,
result);
627 result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator,
result);
633template <
typename Source>
634struct VectorToScalarMapper;
636struct VectorToScalarMapper<
LLVM::vector_reduce_fmaximum> {
637 using Type = LLVM::MaximumOp;
640struct VectorToScalarMapper<
LLVM::vector_reduce_fminimum> {
641 using Type = LLVM::MinimumOp;
644struct VectorToScalarMapper<
LLVM::vector_reduce_fmax> {
645 using Type = LLVM::MaxNumOp;
648struct VectorToScalarMapper<
LLVM::vector_reduce_fmin> {
649 using Type = LLVM::MinNumOp;
653template <
class LLVMRedIntrinOp>
654static Value createFPReductionComparisonOpLowering(
655 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
656 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
658 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
661 result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
662 rewriter, loc,
result, accumulator);
/// Empty tag types selecting which `getMaskNeutralValue` overload supplies the
/// filler for masked-off lanes: one for floating-point MAXIMUMF reductions and
/// one for MINIMUMF reductions.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
674getMaskNeutralValue(MaskNeutralFMaximum,
675 const llvm::fltSemantics &floatSemantics) {
676 return llvm::APFloat::getSmallest(floatSemantics,
true);
680getMaskNeutralValue(MaskNeutralFMinimum,
681 const llvm::fltSemantics &floatSemantics) {
682 return llvm::APFloat::getLargest(floatSemantics,
false);
686template <
typename MaskNeutral>
687static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
690 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
691 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
693 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
700template <
class LLVMRedIntrinOp,
class MaskNeutral>
702lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
705 Value mask, LLVM::FastmathFlagsAttr fmf) {
706 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
707 rewriter, loc, llvmType, vectorOperand.
getType());
708 const Value selectedVectorByMask = LLVM::SelectOp::create(
709 rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
710 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
711 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
714template <
class LLVMRedIntrinOp,
class ReductionNeutral>
716lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc,
718 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
719 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
720 llvmType, accumulator);
721 return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
722 accumulator, vectorOperand,
729template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
731lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
734 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
735 llvmType, accumulator);
736 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
737 accumulator, vectorOperand);
740template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
741static Value lowerPredicatedReductionWithStartValue(
742 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
744 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
745 llvmType, accumulator);
747 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
748 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
749 accumulator, vectorOperand,
753template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
754 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
755static Value lowerPredicatedReductionWithStartValue(
756 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
759 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
760 IntReductionNeutral>(
761 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
764 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
766 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
770class VectorReductionOpConversion
773 explicit VectorReductionOpConversion(
const LLVMTypeConverter &typeConv,
774 bool reassociateFPRed)
775 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
776 reassociateFPReductions(reassociateFPRed) {}
779 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
780 ConversionPatternRewriter &rewriter)
const override {
781 auto kind = reductionOp.getKind();
782 Type eltType = reductionOp.getDest().getType();
783 Type llvmType = typeConverter->convertType(eltType);
784 Value operand = adaptor.getVector();
785 Value acc = adaptor.getAcc();
786 Location loc = reductionOp.getLoc();
792 case vector::CombiningKind::ADD:
794 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
796 rewriter, loc, llvmType, operand, acc);
798 case vector::CombiningKind::MUL:
800 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
802 rewriter, loc, llvmType, operand, acc);
804 case vector::CombiningKind::MINUI:
805 result = createIntegerReductionComparisonOpLowering<
806 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
807 LLVM::ICmpPredicate::ule);
809 case vector::CombiningKind::MINSI:
810 result = createIntegerReductionComparisonOpLowering<
811 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
812 LLVM::ICmpPredicate::sle);
814 case vector::CombiningKind::MAXUI:
815 result = createIntegerReductionComparisonOpLowering<
816 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
817 LLVM::ICmpPredicate::uge);
819 case vector::CombiningKind::MAXSI:
820 result = createIntegerReductionComparisonOpLowering<
821 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
822 LLVM::ICmpPredicate::sge);
824 case vector::CombiningKind::AND:
826 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
828 rewriter, loc, llvmType, operand, acc);
830 case vector::CombiningKind::OR:
832 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
834 rewriter, loc, llvmType, operand, acc);
836 case vector::CombiningKind::XOR:
838 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
840 rewriter, loc, llvmType, operand, acc);
845 rewriter.replaceOp(reductionOp,
result);
850 if (!isa<FloatType>(eltType))
853 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
854 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
855 reductionOp.getContext(),
857 fmf = LLVM::FastmathFlagsAttr::get(
858 reductionOp.getContext(),
859 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
860 : LLVM::FastmathFlags::none));
864 if (kind == vector::CombiningKind::ADD) {
865 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
866 ReductionNeutralZero>(
867 rewriter, loc, llvmType, operand, acc, fmf);
868 }
else if (kind == vector::CombiningKind::MUL) {
869 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
870 ReductionNeutralFPOne>(
871 rewriter, loc, llvmType, operand, acc, fmf);
872 }
else if (kind == vector::CombiningKind::MINIMUMF) {
874 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
875 rewriter, loc, llvmType, operand, acc, fmf);
876 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
878 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
879 rewriter, loc, llvmType, operand, acc, fmf);
880 }
else if (kind == vector::CombiningKind::MINNUMF) {
881 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
882 rewriter, loc, llvmType, operand, acc, fmf);
883 }
else if (kind == vector::CombiningKind::MAXNUMF) {
884 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
885 rewriter, loc, llvmType, operand, acc, fmf);
890 rewriter.replaceOp(reductionOp,
result);
895 const bool reassociateFPReductions;
906template <
class MaskedOp>
907class VectorMaskOpConversionBase
910 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
913 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
914 ConversionPatternRewriter &rewriter)
const final {
916 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
919 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
923 virtual LogicalResult
924 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
925 vector::MaskableOpInterface maskableOp,
926 ConversionPatternRewriter &rewriter)
const = 0;
929class MaskedReductionOpConversion
930 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
933 using VectorMaskOpConversionBase<
934 vector::ReductionOp>::VectorMaskOpConversionBase;
936 LogicalResult matchAndRewriteMaskableOp(
937 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
938 ConversionPatternRewriter &rewriter)
const override {
939 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
940 auto kind = reductionOp.getKind();
941 Type eltType = reductionOp.getDest().getType();
942 Type llvmType = typeConverter->convertType(eltType);
943 Value operand = reductionOp.getVector();
944 Value acc = reductionOp.getAcc();
945 Location loc = reductionOp.getLoc();
947 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
948 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
949 reductionOp.getContext(),
954 case vector::CombiningKind::ADD:
955 result = lowerPredicatedReductionWithStartValue<
956 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
957 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
960 case vector::CombiningKind::MUL:
961 result = lowerPredicatedReductionWithStartValue<
962 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
963 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
966 case vector::CombiningKind::MINUI:
967 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
968 ReductionNeutralUIntMax>(
969 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
971 case vector::CombiningKind::MINSI:
972 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
973 ReductionNeutralSIntMax>(
974 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
976 case vector::CombiningKind::MAXUI:
977 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
978 ReductionNeutralUIntMin>(
979 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
981 case vector::CombiningKind::MAXSI:
982 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
983 ReductionNeutralSIntMin>(
984 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
986 case vector::CombiningKind::AND:
987 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
988 ReductionNeutralAllOnes>(
989 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
991 case vector::CombiningKind::OR:
992 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
993 ReductionNeutralZero>(
994 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
996 case vector::CombiningKind::XOR:
997 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
998 ReductionNeutralZero>(
999 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1001 case vector::CombiningKind::MINNUMF:
1002 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1003 ReductionNeutralFPMax>(
1004 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1006 case vector::CombiningKind::MAXNUMF:
1007 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1008 ReductionNeutralFPMin>(
1009 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1011 case CombiningKind::MAXIMUMF:
1012 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1013 MaskNeutralFMaximum>(
1014 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1016 case CombiningKind::MINIMUMF:
1017 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1018 MaskNeutralFMinimum>(
1019 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1024 rewriter.replaceOp(maskOp,
result);
1029class VectorShuffleOpConversion
1032 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1035 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1036 ConversionPatternRewriter &rewriter)
const override {
1037 auto loc = shuffleOp->getLoc();
1038 auto v1Type = shuffleOp.getV1VectorType();
1039 auto v2Type = shuffleOp.getV2VectorType();
1040 auto vectorType = shuffleOp.getResultVectorType();
1041 Type llvmType = typeConverter->convertType(vectorType);
1042 ArrayRef<int64_t> mask = shuffleOp.getMask();
1049 int64_t rank = vectorType.getRank();
1051 bool wellFormed0DCase =
1052 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1053 bool wellFormedNDCase =
1054 v1Type.getRank() == rank && v2Type.getRank() == rank;
1055 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1060 if (rank <= 1 && v1Type == v2Type) {
1061 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1062 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1063 llvm::to_vector_of<int32_t>(mask));
1064 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1069 int64_t v1Dim = v1Type.getDimSize(0);
1071 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1072 eltType = arrayType.getElementType();
1074 eltType = cast<VectorType>(llvmType).getElementType();
1075 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1077 for (int64_t extPos : mask) {
1078 Value value = adaptor.getV1();
1079 if (extPos >= v1Dim) {
1081 value = adaptor.getV2();
1083 Value extract =
extractOne(rewriter, *getTypeConverter(), loc, value,
1084 eltType, rank, extPos);
1085 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1086 llvmType, rank, insPos++);
1088 rewriter.replaceOp(shuffleOp, insert);
1093class VectorExtractOpConversion
1096 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1099 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1100 ConversionPatternRewriter &rewriter)
const override {
1101 auto loc = extractOp->getLoc();
1102 auto resultType = extractOp.getResult().getType();
1103 auto llvmResultType = typeConverter->convertType(resultType);
1105 if (!llvmResultType)
1109 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1123 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1127 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1128 extractOp.getSourceVectorType().getRank();
1132 if (extractOp.getSourceVectorType().getRank() == 0) {
1133 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1134 positionVec.push_back(rewriter.getZeroAttr(idxType));
1137 Value extracted = adaptor.getSource();
1138 if (extractsAggregate) {
1139 ArrayRef<OpFoldResult> position(positionVec);
1140 if (extractsScalar) {
1144 position = position.drop_back();
1147 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1150 extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1154 if (extractsScalar) {
1155 extracted = LLVM::ExtractElementOp::create(
1156 rewriter, loc, extracted,
1160 rewriter.replaceOp(extractOp, extracted);
1181 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1184 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1185 ConversionPatternRewriter &rewriter)
const override {
1186 VectorType vType = fmaOp.getVectorType();
1187 if (vType.getRank() > 1)
1190 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1191 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1196class VectorInsertOpConversion
1199 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1202 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1203 ConversionPatternRewriter &rewriter)
const override {
1204 auto loc = insertOp->getLoc();
1205 auto destVectorType = insertOp.getDestVectorType();
1206 auto llvmResultType = typeConverter->convertType(destVectorType);
1208 if (!llvmResultType)
1212 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1234 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1236 bool insertIntoInnermostDim =
1237 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
1239 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1240 positionVec.begin(),
1241 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1242 OpFoldResult positionOfScalarWithin1DVector;
1243 if (destVectorType.getRank() == 0) {
1246 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1247 positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1248 }
else if (insertIntoInnermostDim) {
1249 positionOfScalarWithin1DVector = positionVec.back();
1255 Value sourceAggregate = adaptor.getValueToStore();
1256 if (insertIntoInnermostDim) {
1259 if (isNestedAggregate) {
1262 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1263 llvm::IsaPred<Attribute>)) {
1267 sourceAggregate = LLVM::ExtractValueOp::create(
1268 rewriter, loc, adaptor.getDest(),
1273 sourceAggregate = adaptor.getDest();
1276 sourceAggregate = LLVM::InsertElementOp::create(
1277 rewriter, loc, sourceAggregate.
getType(), sourceAggregate,
1278 adaptor.getValueToStore(),
1282 Value
result = sourceAggregate;
1283 if (isNestedAggregate) {
1284 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1285 llvm::IsaPred<Attribute>)) {
1289 result = LLVM::InsertValueOp::create(
1290 rewriter, loc, adaptor.getDest(), sourceAggregate,
1294 rewriter.replaceOp(insertOp,
result);
1300struct VectorScalableInsertOpLowering
1302 using ConvertOpToLLVMPattern<
1303 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1306 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1307 ConversionPatternRewriter &rewriter)
const override {
1308 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1309 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1315struct VectorScalableExtractOpLowering
1317 using ConvertOpToLLVMPattern<
1318 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1321 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1322 ConversionPatternRewriter &rewriter)
const override {
1323 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1324 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1325 adaptor.getSource(), adaptor.getPos());
1358 setHasBoundedRewriteRecursion();
1361 LogicalResult matchAndRewrite(FMAOp op,
1362 PatternRewriter &rewriter)
const override {
1363 auto vType = op.getVectorType();
1364 if (vType.getRank() < 2)
1367 auto loc = op.getLoc();
1368 auto elemType = vType.getElementType();
1369 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1371 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1372 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1373 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1374 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1375 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1376 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1377 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1386static std::optional<SmallVector<int64_t, 4>>
1387computeContiguousStrides(MemRefType memRefType) {
1390 if (
failed(memRefType.getStridesAndOffset(strides, offset)))
1391 return std::nullopt;
1392 if (!strides.empty() && strides.back() != 1)
1393 return std::nullopt;
1395 if (memRefType.getLayout().isIdentity())
1402 auto sizes = memRefType.getShape();
1404 if (ShapedType::isDynamic(sizes[
index + 1]) ||
1405 ShapedType::isDynamic(strides[
index]) ||
1406 ShapedType::isDynamic(strides[
index + 1]))
1407 return std::nullopt;
1409 return std::nullopt;
1414class VectorTypeCastOpConversion
1417 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1420 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1421 ConversionPatternRewriter &rewriter)
const override {
1422 auto loc = castOp->getLoc();
1423 MemRefType sourceMemRefType =
1424 cast<MemRefType>(castOp.getOperand().getType());
1425 MemRefType targetMemRefType = castOp.getType();
1428 if (!sourceMemRefType.hasStaticShape() ||
1429 !targetMemRefType.hasStaticShape())
1432 auto llvmSourceDescriptorTy =
1433 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1434 if (!llvmSourceDescriptorTy)
1436 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1438 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1439 typeConverter->convertType(targetMemRefType));
1440 if (!llvmTargetDescriptorTy)
1444 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1447 auto targetStrides = computeContiguousStrides(targetMemRefType);
1451 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1454 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1457 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1459 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1460 desc.setAllocatedPtr(rewriter, loc, allocated);
1463 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1464 desc.setAlignedPtr(rewriter, loc, ptr);
1466 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1467 auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1468 desc.setOffset(rewriter, loc, zero);
1471 for (
const auto &indexedSize :
1472 llvm::enumerate(targetMemRefType.getShape())) {
1473 int64_t index = indexedSize.index();
1475 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1476 auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1477 desc.setSize(rewriter, loc, index, size);
1478 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1479 (*targetStrides)[index]);
1481 LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1482 desc.setStride(rewriter, loc, index, stride);
1485 rewriter.replaceOp(castOp, {desc});
1492class VectorCreateMaskOpConversion
1493 :
public OpConversionPattern<vector::CreateMaskOp> {
1495 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1496 bool enableIndexOpt)
1497 : OpConversionPattern<vector::CreateMaskOp>(context),
1498 force32BitVectorIndices(enableIndexOpt) {}
1501 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1502 ConversionPatternRewriter &rewriter)
const override {
1503 auto dstType = op.getType();
1504 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1506 IntegerType idxType =
1507 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1508 auto loc = op->getLoc();
1509 Value
indices = LLVM::StepVectorOp::create(
1511 LLVM::getVectorType(idxType, dstType.getShape()[0],
1513 Value maskBound = adaptor.getOperands()[0];
1520 if (force32BitVectorIndices) {
1523 maskBound = arith::MinSIOp::create(rewriter, loc, maskBound, maxBound);
1527 Value bounds = BroadcastOp::create(rewriter, loc,
indices.getType(), bound);
1528 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1530 rewriter.replaceOp(op, comp);
1535 const bool force32BitVectorIndices;
1539 SymbolTableCollection *symbolTables =
nullptr;
1542 explicit VectorPrintOpConversion(
1543 const LLVMTypeConverter &typeConverter,
1544 SymbolTableCollection *symbolTables =
nullptr)
1545 : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1546 symbolTables(symbolTables) {}
1562 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1563 ConversionPatternRewriter &rewriter)
const override {
1564 auto parent =
printOp->getParentOfType<ModuleOp>();
1570 if (
auto value = adaptor.getSource()) {
1572 if (isa<VectorType>(printType)) {
1576 if (
failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1580 auto punct =
printOp.getPunctuation();
1581 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1583 LLVM::createPrintStrCall(rewriter, loc, parent,
"vector_print_str",
1584 *stringLiteral, *getTypeConverter(),
1586 if (createResult.failed())
1589 }
else if (punct != PrintPunctuation::NoPunctuation) {
1590 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1592 case PrintPunctuation::Close:
1593 return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
1595 case PrintPunctuation::Open:
1596 return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
1598 case PrintPunctuation::Comma:
1599 return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
1601 case PrintPunctuation::NewLine:
1602 return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
1605 llvm_unreachable(
"unexpected punctuation");
1610 emitCall(rewriter,
printOp->getLoc(), op.value());
1618 enum class PrintConversion {
1627 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1628 ModuleOp parent, Location loc, Type printType,
1629 Value value)
const {
1630 if (typeConverter->convertType(printType) ==
nullptr)
1634 PrintConversion conversion = PrintConversion::None;
1635 FailureOr<Operation *> printer;
1637 printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
1639 printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
1641 conversion = PrintConversion::Bitcast16;
1642 printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
1644 conversion = PrintConversion::Bitcast16;
1645 printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
1647 printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
1648 }
else if (
auto intTy = dyn_cast<IntegerType>(printType)) {
1652 unsigned width = intTy.getWidth();
1653 if (intTy.isUnsigned()) {
1656 conversion = PrintConversion::ZeroExt64;
1658 LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
1663 assert(intTy.isSignless() || intTy.isSigned());
1668 conversion = PrintConversion::ZeroExt64;
1669 else if (width < 64)
1670 conversion = PrintConversion::SignExt64;
1672 LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
1677 }
else if (
auto floatTy = dyn_cast<FloatType>(printType)) {
1680 llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1681 Value semValue = LLVM::ConstantOp::create(
1682 rewriter, loc, rewriter.getI32Type(),
1683 rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1685 LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1687 LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
1688 emitCall(rewriter, loc, printer.value(),
1697 switch (conversion) {
1698 case PrintConversion::ZeroExt64:
1699 value = arith::ExtUIOp::create(
1700 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1702 case PrintConversion::SignExt64:
1703 value = arith::ExtSIOp::create(
1704 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1706 case PrintConversion::Bitcast16:
1707 value = LLVM::BitcastOp::create(
1708 rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
1710 case PrintConversion::None:
1713 emitCall(rewriter, loc, printer.value(), value);
1718 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1720 LLVM::CallOp::create(rewriter, loc,
TypeRange(), SymbolRefAttr::get(ref),
1728struct VectorBroadcastScalarToLowRankLowering
1730 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1733 matchAndRewrite(vector::BroadcastOp
broadcast, OpAdaptor adaptor,
1734 ConversionPatternRewriter &rewriter)
const override {
1735 if (isa<VectorType>(
broadcast.getSourceType()))
1736 return rewriter.notifyMatchFailure(
1737 broadcast,
"broadcast from vector type not handled");
1740 if (resultType.getRank() > 1)
1741 return rewriter.notifyMatchFailure(
broadcast,
1742 "broadcast to 2+-d handled elsewhere");
1748 auto zero = LLVM::ConstantOp::create(
1750 typeConverter->convertType(rewriter.getIntegerType(32)),
1751 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1754 if (resultType.getRank() == 0) {
1755 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1756 broadcast, vectorType, poison, adaptor.getSource(), zero);
1761 LLVM::InsertElementOp::create(rewriter,
broadcast.
getLoc(), vectorType,
1762 poison, adaptor.getSource(), zero);
1766 SmallVector<int32_t> zeroValues(width, 0);
1769 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1780struct VectorBroadcastScalarToNdLowering
1782 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1785 matchAndRewrite(BroadcastOp
broadcast, OpAdaptor adaptor,
1786 ConversionPatternRewriter &rewriter)
const override {
1787 if (isa<VectorType>(
broadcast.getSourceType()))
1788 return rewriter.notifyMatchFailure(
1789 broadcast,
"broadcast from vector type not handled");
1792 if (resultType.getRank() <= 1)
1793 return rewriter.notifyMatchFailure(
1794 broadcast,
"broadcast to 1-d or 0-d handled elsewhere");
1798 auto vectorTypeInfo =
1800 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1801 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1802 if (!llvmNDVectorTy || !llvm1DVectorTy)
1806 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1810 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1811 auto zero = LLVM::ConstantOp::create(
1812 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1813 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1814 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1815 vdesc, adaptor.getSource(), zero);
1818 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1819 SmallVector<int32_t> zeroValues(width, 0);
1820 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1824 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1825 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1834struct VectorInterleaveOpLowering
1839 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1840 ConversionPatternRewriter &rewriter)
const override {
1841 VectorType resultType = interleaveOp.getResultVectorType();
1843 if (resultType.getRank() != 1)
1844 return rewriter.notifyMatchFailure(interleaveOp,
1845 "InterleaveOp not rank 1");
1847 if (resultType.isScalable()) {
1848 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1849 interleaveOp, typeConverter->convertType(resultType),
1850 adaptor.getLhs(), adaptor.getRhs());
1857 int64_t resultVectorSize = resultType.getNumElements();
1858 SmallVector<int32_t> interleaveShuffleMask;
1859 interleaveShuffleMask.reserve(resultVectorSize);
1860 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1861 interleaveShuffleMask.push_back(i);
1862 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1864 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1865 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1866 interleaveShuffleMask);
1873struct VectorDeinterleaveOpLowering
1878 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1879 ConversionPatternRewriter &rewriter)
const override {
1880 VectorType resultType = deinterleaveOp.getResultVectorType();
1881 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1882 auto loc = deinterleaveOp.getLoc();
1886 if (resultType.getRank() != 1)
1887 return rewriter.notifyMatchFailure(deinterleaveOp,
1888 "DeinterleaveOp not rank 1");
1890 if (resultType.isScalable()) {
1891 const auto *llvmTypeConverter = this->getTypeConverter();
1892 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1893 auto packedOpResults =
1894 llvmTypeConverter->packOperationResults(deinterleaveResults);
1895 auto intrinsic = LLVM::vector_deinterleave2::create(
1896 rewriter, loc, packedOpResults, adaptor.getSource());
1898 auto evenResult = LLVM::ExtractValueOp::create(
1899 rewriter, loc, intrinsic->getResult(0), 0);
1900 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1901 intrinsic->getResult(0), 1);
1903 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenResult, oddResult});
1910 int64_t resultVectorSize = resultType.getNumElements();
1911 SmallVector<int32_t> evenShuffleMask;
1912 SmallVector<int32_t> oddShuffleMask;
1914 evenShuffleMask.reserve(resultVectorSize);
1915 oddShuffleMask.reserve(resultVectorSize);
1917 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1919 evenShuffleMask.push_back(i);
1921 oddShuffleMask.push_back(i);
1924 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1925 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1926 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1927 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1928 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1930 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenShuffle, oddShuffle});
1936struct VectorFromElementsLowering
1941 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1942 ConversionPatternRewriter &rewriter)
const override {
1943 Location loc = fromElementsOp.getLoc();
1944 VectorType vectorType = fromElementsOp.getType();
1948 if (vectorType.getRank() > 1)
1949 return rewriter.notifyMatchFailure(fromElementsOp,
1950 "rank > 1 vectors are not supported");
1951 Type llvmType = typeConverter->convertType(vectorType);
1952 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1953 Value
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1954 for (
auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1956 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1957 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType,
result,
1960 rewriter.replaceOp(fromElementsOp,
result);
1966struct VectorToElementsLowering
1971 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1972 ConversionPatternRewriter &rewriter)
const override {
1973 Location loc = toElementsOp.getLoc();
1974 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1975 Value source = adaptor.getSource();
1977 SmallVector<Value> results(toElementsOp->getNumResults());
1978 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1980 if (element.use_empty())
1983 auto constIdx = LLVM::ConstantOp::create(
1984 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1985 auto llvmType = typeConverter->convertType(element.getType());
1987 Value
result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1992 rewriter.replaceOp(toElementsOp, results);
1998struct VectorScalableStepOpLowering
2003 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
2004 ConversionPatternRewriter &rewriter)
const override {
2005 auto resultType = cast<VectorType>(stepOp.getType());
2006 if (!resultType.isScalable()) {
2009 Type llvmType = typeConverter->convertType(stepOp.getType());
2010 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
2025class ContractionOpToMatmulOpLowering
2028 using MaskableOpRewritePattern::MaskableOpRewritePattern;
2030 ContractionOpToMatmulOpLowering(MLIRContext *context,
2031 PatternBenefit benefit = 100)
2032 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2035 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
2036 PatternRewriter &rewriter)
const override;
2056FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2057 vector::ContractionOp op, MaskingOpInterface maskOp,
2063 auto iteratorTypes = op.getIteratorTypes().getValue();
2069 Type opResType = op.getType();
2070 VectorType vecType = dyn_cast<VectorType>(opResType);
2071 if (vecType && vecType.isScalable()) {
2076 Type elementType = op.getLhsType().getElementType();
2080 Type dstElementType = vecType ? vecType.getElementType() : opResType;
2081 if (elementType != dstElementType)
2086 MLIRContext *ctx = op.getContext();
2087 Location loc = op.getLoc();
2091 Value
lhs = op.getLhs();
2092 auto lhsMap = op.getIndexingMapsArray()[0];
2094 lhs = vector::TransposeOp::create(rew, loc,
lhs, ArrayRef<int64_t>{1, 0});
2099 Value
rhs = op.getRhs();
2100 auto rhsMap = op.getIndexingMapsArray()[1];
2102 rhs = vector::TransposeOp::create(rew, loc,
rhs, ArrayRef<int64_t>{1, 0});
2107 VectorType lhsType = cast<VectorType>(
lhs.getType());
2108 VectorType rhsType = cast<VectorType>(
rhs.getType());
2109 int64_t lhsRows = lhsType.getDimSize(0);
2110 int64_t lhsColumns = lhsType.getDimSize(1);
2111 int64_t rhsColumns = rhsType.getDimSize(1);
2113 Type flattenedLHSType =
2114 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2115 lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType,
lhs);
2117 Type flattenedRHSType =
2118 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2119 rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType,
rhs);
2121 Value
mul = LLVM::MatrixMultiplyOp::create(
2123 VectorType::get(lhsRows * rhsColumns,
2124 cast<VectorType>(
lhs.getType()).getElementType()),
2125 lhs,
rhs, lhsRows, lhsColumns, rhsColumns);
2127 mul = vector::ShapeCastOp::create(
2129 VectorType::get({lhsRows, rhsColumns},
2134 auto accMap = op.getIndexingMapsArray()[2];
2136 mul = vector::TransposeOp::create(rew, loc,
mul, ArrayRef<int64_t>{1, 0});
2138 llvm_unreachable(
"invalid contraction semantics");
2140 Value res = isa<IntegerType>(elementType)
2141 ?
static_cast<Value
>(
2142 arith::AddIOp::create(rew, loc, op.getAcc(),
mul))
2143 : static_cast<Value>(
2144 arith::AddFOp::create(rew, loc, op.getAcc(),
mul));
2162class TransposeOpToMatrixTransposeOpLowering
2163 :
public OpRewritePattern<vector::TransposeOp> {
2167 LogicalResult matchAndRewrite(vector::TransposeOp op,
2168 PatternRewriter &rewriter)
const override {
2169 auto loc = op.getLoc();
2171 Value input = op.getVector();
2172 VectorType inputType = op.getSourceVectorType();
2173 VectorType resType = op.getResultVectorType();
2175 if (inputType.isScalable())
2177 op,
"This lowering does not support scalable vectors");
2180 ArrayRef<int64_t> transp = op.getPermutation();
2182 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2186 Type flattenedType =
2187 VectorType::get(resType.getNumElements(), resType.getElementType());
2189 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2192 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2193 matrix, rows, columns);
2203 patterns.
add<VectorFMAOpNDRewritePattern>(patterns.
getContext());
2208 patterns.
add<ContractionOpToMatmulOpLowering>(patterns.
getContext(), benefit);
2213 patterns.
add<TransposeOpToMatrixTransposeOpLowering>(patterns.
getContext(),
2220 bool reassociateFPReductions,
bool force32BitVectorIndices,
2221 bool useVectorAlignment) {
2224 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2225 patterns.
add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2226 patterns.
add<VectorLoadStoreConversion<vector::LoadOp>,
2227 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2228 VectorLoadStoreConversion<vector::StoreOp>,
2229 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2230 VectorGatherOpConversion, VectorScatterOpConversion>(
2231 converter, useVectorAlignment);
2232 patterns.
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2233 VectorExtractOpConversion, VectorFMAOp1DConversion,
2234 VectorInsertOpConversion, VectorPrintOpConversion,
2235 VectorTypeCastOpConversion, VectorScaleOpConversion,
2236 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2237 VectorBroadcastScalarToLowRankLowering,
2238 VectorBroadcastScalarToNdLowering,
2239 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2240 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2241 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2242 VectorToElementsLowering, VectorScalableStepOpLowering>(
2247struct VectorToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
2248 VectorToLLVMDialectInterface(
Dialect *dialect)
2249 : ConvertToLLVMPatternInterface(dialect) {}
2251 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2252 void loadDependentDialects(MLIRContext *context)
const final {
2253 context->loadDialect<LLVM::LLVMDialect>();
2258 void populateConvertToLLVMConversionPatterns(
2259 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2260 RewritePatternSet &patterns)
const final {
2269 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Conversion from types to the LLVM IR dialect.
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true will be replaced by dynamicValues.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e. inside a vector.mask op region).