31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
45 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
47 auto idxType = rewriter.getIndexType();
48 auto constant = LLVM::ConstantOp::create(
49 rewriter, loc, typeConverter.convertType(idxType),
50 rewriter.getIntegerAttr(idxType, pos));
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
62 auto idxType = rewriter.getIndexType();
63 auto constant = LLVM::ConstantOp::create(
64 rewriter, loc, typeConverter.convertType(idxType),
65 rewriter.getIntegerAttr(idxType, pos));
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
74 VectorType vectorType,
unsigned &align) {
75 Type convertedVectorTy = typeConverter.convertType(vectorType);
76 if (!convertedVectorTy)
79 llvm::LLVMContext llvmContext;
89 MemRefType memrefType,
unsigned &align) {
90 Type elementTy = typeConverter.convertType(memrefType.getElementType());
96 llvm::LLVMContext llvmContext;
109 VectorType vectorType,
110 MemRefType memrefType,
unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
127 if (!memRefType.isLastDimUnitStride())
137 MemRefType memRefType,
Value llvmMemref,
Value base,
140 "unsupported memref type");
141 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
145 vectorType.getScalableDims()[0]);
146 return LLVM::GEPOp::create(
147 rewriter, loc, ptrsType,
148 typeConverter.convertType(memRefType.getElementType()), base,
index);
155 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
160 return cast<Value>(foldResult);
166using VectorScaleOpConversion =
170class VectorBitCastOpConversion
173 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
176 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter)
const override {
179 VectorType resultTy = bitCastOp.getResultVectorType();
180 if (resultTy.getRank() > 1)
182 Type newResultTy = typeConverter->convertType(resultTy);
183 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
184 adaptor.getOperands()[0]);
192static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193 vector::LoadOpAdaptor adaptor,
194 VectorType vectorTy,
Value ptr,
unsigned align,
195 ConversionPatternRewriter &rewriter) {
196 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy,
ptr, align,
198 loadOp.getNontemporal());
201static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202 vector::MaskedLoadOpAdaptor adaptor,
203 VectorType vectorTy,
Value ptr,
unsigned align,
204 ConversionPatternRewriter &rewriter) {
205 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206 loadOp, vectorTy,
ptr, adaptor.getMask(), adaptor.getPassThru(), align);
209static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210 vector::StoreOpAdaptor adaptor,
211 VectorType vectorTy,
Value ptr,
unsigned align,
212 ConversionPatternRewriter &rewriter) {
213 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
215 storeOp.getNontemporal());
218static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219 vector::MaskedStoreOpAdaptor adaptor,
220 VectorType vectorTy,
Value ptr,
unsigned align,
221 ConversionPatternRewriter &rewriter) {
222 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223 storeOp, adaptor.getValueToStore(),
ptr, adaptor.getMask(), align);
228template <
class LoadOrStoreOp>
231 explicit VectorLoadStoreConversion(
const LLVMTypeConverter &typeConv,
233 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
234 useVectorAlignment(useVectorAlign) {}
235 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter)
const override {
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
254 memRefTy, align, useVectorAlignment)))
255 return rewriter.notifyMatchFailure(loadOrStoreOp,
256 "could not resolve alignment");
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
273 const bool useVectorAlignment;
277class VectorGatherOpConversion
280 explicit VectorGatherOpConversion(
const LLVMTypeConverter &typeConv,
282 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
283 useVectorAlignment(useVectorAlign) {}
284 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 Location loc = gather->getLoc();
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType &&
"The base should be bufferized");
294 return rewriter.notifyMatchFailure(gather,
"memref type not supported");
296 VectorType vType = gather.getVectorType();
297 if (vType.getRank() > 1) {
298 return rewriter.notifyMatchFailure(
299 gather,
"only 1-D vectors can be lowered to LLVM");
304 unsigned align = gather.getAlignment().value_or(0);
307 memRefType, align, useVectorAlignment)))
308 return rewriter.notifyMatchFailure(gather,
"could not resolve alignment");
312 adaptor.getBase(), adaptor.getOffsets());
313 Value base = adaptor.getBase();
315 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
316 base, ptr, adaptor.getIndices(), vType);
319 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
320 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
321 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
330 const bool useVectorAlignment;
334class VectorScatterOpConversion
337 explicit VectorScatterOpConversion(
const LLVMTypeConverter &typeConv,
339 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
340 useVectorAlignment(useVectorAlign) {}
342 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
345 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter)
const override {
347 auto loc = scatter->getLoc();
348 auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
349 assert(memRefType &&
"The base should be bufferized");
352 return rewriter.notifyMatchFailure(scatter,
"memref type not supported");
354 VectorType vType = scatter.getVectorType();
355 if (vType.getRank() > 1) {
356 return rewriter.notifyMatchFailure(
357 scatter,
"only 1-D vectors can be lowered to LLVM");
362 unsigned align = scatter.getAlignment().value_or(0);
365 memRefType, align, useVectorAlignment)))
366 return rewriter.notifyMatchFailure(scatter,
367 "could not resolve alignment");
371 adaptor.getBase(), adaptor.getOffsets());
373 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
374 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
377 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
378 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
379 rewriter.getI32IntegerAttr(align));
388 const bool useVectorAlignment;
392class VectorExpandLoadOpConversion
395 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
398 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter)
const override {
400 auto loc = expand->getLoc();
401 MemRefType memRefType = expand.getMemRefType();
404 auto vtype = typeConverter->convertType(expand.getVectorType());
406 adaptor.getBase(), adaptor.getIndices());
411 uint64_t alignment = expand.getAlignment().value_or(1);
413 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
414 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
421class VectorCompressStoreOpConversion
424 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
427 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const override {
429 auto loc = compress->getLoc();
430 MemRefType memRefType = compress.getMemRefType();
434 adaptor.getBase(), adaptor.getIndices());
439 uint64_t alignment = compress.getAlignment().value_or(1);
441 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
442 compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
// Empty tag types naming the "neutral" start value of a reduction. They carry
// no state: each one only selects an overload of createReductionNeutralValue,
// which takes the tag as its first parameter. getOrCreateAccumulator
// default-constructs its ReductionNeutral template argument and forwards it to
// that overload set, and the masked-reduction lowering passes these tags as
// template arguments alongside the LLVM VP reduction op.

// Selects the overload that builds the constant from getZeroAttr(llvmType).
448class ReductionNeutralZero {};
// Selects the overload that builds the constant from getIntegerAttr(llvmType, 1).
449class ReductionNeutralIntOne {};
// Selects the overload that builds the constant from getFloatAttr(llvmType, 1.0).
450class ReductionNeutralFPOne {};
// Selects an overload that builds an integer attribute; presumably an all-ones
// bit pattern per the name -- the attribute value is not visible in this
// excerpt, TODO confirm.
451class ReductionNeutralAllOnes {};
// Selects the overload based on llvm::APInt::getSignedMinValue.
452class ReductionNeutralSIntMin {};
// Selects the overload based on llvm::APInt::getMinValue.
453class ReductionNeutralUIntMin {};
// Selects the overload based on llvm::APInt::getSignedMaxValue.
454class ReductionNeutralSIntMax {};
// Selects the overload based on llvm::APInt::getMaxValue.
455class ReductionNeutralUIntMax {};
// Selects a floating-point overload; the visible code builds the constant from
// llvm::APFloat::getQNaN on the type's float semantics (remaining arguments
// are not visible here).
456class ReductionNeutralFPMin {};
// Selects a floating-point overload; the visible code likewise builds the
// constant from llvm::APFloat::getQNaN (remaining arguments are not visible
// here).
457class ReductionNeutralFPMax {};
460static Value createReductionNeutralValue(ReductionNeutralZero neutral,
461 ConversionPatternRewriter &rewriter,
463 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
464 rewriter.getZeroAttr(llvmType));
468static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
469 ConversionPatternRewriter &rewriter,
471 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
472 rewriter.getIntegerAttr(llvmType, 1));
476static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
477 ConversionPatternRewriter &rewriter,
479 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
480 rewriter.getFloatAttr(llvmType, 1.0));
484static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
485 ConversionPatternRewriter &rewriter,
487 return LLVM::ConstantOp::create(
488 rewriter, loc, llvmType,
489 rewriter.getIntegerAttr(
494static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
495 ConversionPatternRewriter &rewriter,
497 return LLVM::ConstantOp::create(
498 rewriter, loc, llvmType,
499 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
504static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
505 ConversionPatternRewriter &rewriter,
507 return LLVM::ConstantOp::create(
508 rewriter, loc, llvmType,
509 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
514static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
515 ConversionPatternRewriter &rewriter,
517 return LLVM::ConstantOp::create(
518 rewriter, loc, llvmType,
519 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
524static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
525 ConversionPatternRewriter &rewriter,
527 return LLVM::ConstantOp::create(
528 rewriter, loc, llvmType,
529 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
534static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
535 ConversionPatternRewriter &rewriter,
537 auto floatType = cast<FloatType>(llvmType);
538 return LLVM::ConstantOp::create(
539 rewriter, loc, llvmType,
540 rewriter.getFloatAttr(
541 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
546static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
547 ConversionPatternRewriter &rewriter,
549 auto floatType = cast<FloatType>(llvmType);
550 return LLVM::ConstantOp::create(
551 rewriter, loc, llvmType,
552 rewriter.getFloatAttr(
553 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
559template <
class ReductionNeutral>
560static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
566 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
573static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
575 VectorType vType = cast<VectorType>(llvmType);
576 auto vShape = vType.getShape();
577 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
579 Value baseVecLength = LLVM::ConstantOp::create(
580 rewriter, loc, rewriter.getI32Type(),
581 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
583 if (!vType.getScalableDims()[0])
584 return baseVecLength;
587 Value vScale = vector::VectorScaleOp::create(rewriter, loc);
589 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
590 Value scalableVecLength =
591 arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
592 return scalableVecLength;
599template <
class LLVMRedIntrinOp,
class ScalarOp>
600static Value createIntegerReductionArithmeticOpLowering(
601 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
605 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
608 result = ScalarOp::create(rewriter, loc, accumulator,
result);
616template <
class LLVMRedIntrinOp>
617static Value createIntegerReductionComparisonOpLowering(
618 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
619 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
621 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
624 LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator,
result);
625 result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator,
result);
// Primary template of a compile-time map from a vector reduction intrinsic op
// (`Source`) to a scalar op type. It is only declared here; the explicit
// specializations in this file expose the mapped op as a nested `Type` alias
// (e.g. LLVM::vector_reduce_fmaximum -> LLVM::MaximumOp). The FP comparison
// reduction lowering uses `VectorToScalarMapper<LLVMRedIntrinOp>::Type::create`
// to combine the reduced result with the accumulator.
631template <
typename Source>
632struct VectorToScalarMapper;
634struct VectorToScalarMapper<
LLVM::vector_reduce_fmaximum> {
635 using Type = LLVM::MaximumOp;
638struct VectorToScalarMapper<
LLVM::vector_reduce_fminimum> {
639 using Type = LLVM::MinimumOp;
642struct VectorToScalarMapper<
LLVM::vector_reduce_fmax> {
643 using Type = LLVM::MaxNumOp;
646struct VectorToScalarMapper<
LLVM::vector_reduce_fmin> {
647 using Type = LLVM::MinNumOp;
651template <
class LLVMRedIntrinOp>
652static Value createFPReductionComparisonOpLowering(
653 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
654 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
656 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
659 result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
660 rewriter, loc,
result, accumulator);
// Empty tag types that select an overload of getMaskNeutralValue, which takes
// the tag as its first parameter. createMaskNeutralValue instantiates its
// MaskNeutral template argument (`MaskNeutral{}`) to pick the value, and the
// masked MAXIMUMF / MINIMUMF reduction lowering passes these tags as template
// arguments.

// Selects the overload returning llvm::APFloat::getSmallest(floatSemantics, true).
667class MaskNeutralFMaximum {};
// Selects the overload returning llvm::APFloat::getLargest(floatSemantics, false).
668class MaskNeutralFMinimum {};
672getMaskNeutralValue(MaskNeutralFMaximum,
673 const llvm::fltSemantics &floatSemantics) {
674 return llvm::APFloat::getSmallest(floatSemantics,
true);
678getMaskNeutralValue(MaskNeutralFMinimum,
679 const llvm::fltSemantics &floatSemantics) {
680 return llvm::APFloat::getLargest(floatSemantics,
false);
684template <
typename MaskNeutral>
685static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
688 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
689 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
691 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
698template <
class LLVMRedIntrinOp,
class MaskNeutral>
700lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
703 Value mask, LLVM::FastmathFlagsAttr fmf) {
704 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
705 rewriter, loc, llvmType, vectorOperand.
getType());
706 const Value selectedVectorByMask = LLVM::SelectOp::create(
707 rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
708 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
709 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
712template <
class LLVMRedIntrinOp,
class ReductionNeutral>
714lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc,
716 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
717 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
718 llvmType, accumulator);
719 return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
720 accumulator, vectorOperand,
727template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
729lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
732 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
733 llvmType, accumulator);
734 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
735 accumulator, vectorOperand);
738template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
739static Value lowerPredicatedReductionWithStartValue(
740 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
742 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
743 llvmType, accumulator);
745 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
746 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
747 accumulator, vectorOperand,
751template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
752 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
753static Value lowerPredicatedReductionWithStartValue(
754 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
757 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
758 IntReductionNeutral>(
759 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
762 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
764 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
768class VectorReductionOpConversion
771 explicit VectorReductionOpConversion(
const LLVMTypeConverter &typeConv,
772 bool reassociateFPRed)
773 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
774 reassociateFPReductions(reassociateFPRed) {}
777 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
778 ConversionPatternRewriter &rewriter)
const override {
779 auto kind = reductionOp.getKind();
780 Type eltType = reductionOp.getDest().getType();
781 Type llvmType = typeConverter->convertType(eltType);
782 Value operand = adaptor.getVector();
783 Value acc = adaptor.getAcc();
784 Location loc = reductionOp.getLoc();
790 case vector::CombiningKind::ADD:
792 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
794 rewriter, loc, llvmType, operand, acc);
796 case vector::CombiningKind::MUL:
798 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
800 rewriter, loc, llvmType, operand, acc);
802 case vector::CombiningKind::MINUI:
803 result = createIntegerReductionComparisonOpLowering<
804 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
805 LLVM::ICmpPredicate::ule);
807 case vector::CombiningKind::MINSI:
808 result = createIntegerReductionComparisonOpLowering<
809 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
810 LLVM::ICmpPredicate::sle);
812 case vector::CombiningKind::MAXUI:
813 result = createIntegerReductionComparisonOpLowering<
814 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
815 LLVM::ICmpPredicate::uge);
817 case vector::CombiningKind::MAXSI:
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::sge);
822 case vector::CombiningKind::AND:
824 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
826 rewriter, loc, llvmType, operand, acc);
828 case vector::CombiningKind::OR:
830 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
832 rewriter, loc, llvmType, operand, acc);
834 case vector::CombiningKind::XOR:
836 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
838 rewriter, loc, llvmType, operand, acc);
843 rewriter.replaceOp(reductionOp,
result);
848 if (!isa<FloatType>(eltType))
851 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
852 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
853 reductionOp.getContext(),
855 fmf = LLVM::FastmathFlagsAttr::get(
856 reductionOp.getContext(),
857 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
858 : LLVM::FastmathFlags::none));
862 if (kind == vector::CombiningKind::ADD) {
863 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
864 ReductionNeutralZero>(
865 rewriter, loc, llvmType, operand, acc, fmf);
866 }
else if (kind == vector::CombiningKind::MUL) {
867 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
868 ReductionNeutralFPOne>(
869 rewriter, loc, llvmType, operand, acc, fmf);
870 }
else if (kind == vector::CombiningKind::MINIMUMF) {
872 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
873 rewriter, loc, llvmType, operand, acc, fmf);
874 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
876 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
877 rewriter, loc, llvmType, operand, acc, fmf);
878 }
else if (kind == vector::CombiningKind::MINNUMF) {
879 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 }
else if (kind == vector::CombiningKind::MAXNUMF) {
882 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
883 rewriter, loc, llvmType, operand, acc, fmf);
888 rewriter.replaceOp(reductionOp,
result);
893 const bool reassociateFPReductions;
904template <
class MaskedOp>
905class VectorMaskOpConversionBase
908 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
911 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
912 ConversionPatternRewriter &rewriter)
const final {
914 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
917 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
921 virtual LogicalResult
922 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
923 vector::MaskableOpInterface maskableOp,
924 ConversionPatternRewriter &rewriter)
const = 0;
927class MaskedReductionOpConversion
928 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
931 using VectorMaskOpConversionBase<
932 vector::ReductionOp>::VectorMaskOpConversionBase;
934 LogicalResult matchAndRewriteMaskableOp(
935 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
936 ConversionPatternRewriter &rewriter)
const override {
937 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
938 auto kind = reductionOp.getKind();
939 Type eltType = reductionOp.getDest().getType();
940 Type llvmType = typeConverter->convertType(eltType);
941 Value operand = reductionOp.getVector();
942 Value acc = reductionOp.getAcc();
943 Location loc = reductionOp.getLoc();
945 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
946 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
947 reductionOp.getContext(),
952 case vector::CombiningKind::ADD:
953 result = lowerPredicatedReductionWithStartValue<
954 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
955 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
958 case vector::CombiningKind::MUL:
959 result = lowerPredicatedReductionWithStartValue<
960 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
961 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
964 case vector::CombiningKind::MINUI:
965 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
966 ReductionNeutralUIntMax>(
967 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
969 case vector::CombiningKind::MINSI:
970 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
971 ReductionNeutralSIntMax>(
972 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
974 case vector::CombiningKind::MAXUI:
975 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
976 ReductionNeutralUIntMin>(
977 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
979 case vector::CombiningKind::MAXSI:
980 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
981 ReductionNeutralSIntMin>(
982 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
984 case vector::CombiningKind::AND:
985 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
986 ReductionNeutralAllOnes>(
987 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
989 case vector::CombiningKind::OR:
990 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
991 ReductionNeutralZero>(
992 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
994 case vector::CombiningKind::XOR:
995 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
996 ReductionNeutralZero>(
997 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
999 case vector::CombiningKind::MINNUMF:
1000 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1001 ReductionNeutralFPMax>(
1002 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1004 case vector::CombiningKind::MAXNUMF:
1005 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1006 ReductionNeutralFPMin>(
1007 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1009 case CombiningKind::MAXIMUMF:
1010 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1011 MaskNeutralFMaximum>(
1012 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1014 case CombiningKind::MINIMUMF:
1015 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1016 MaskNeutralFMinimum>(
1017 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1022 rewriter.replaceOp(maskOp,
result);
1027class VectorShuffleOpConversion
1030 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1033 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1034 ConversionPatternRewriter &rewriter)
const override {
1035 auto loc = shuffleOp->getLoc();
1036 auto v1Type = shuffleOp.getV1VectorType();
1037 auto v2Type = shuffleOp.getV2VectorType();
1038 auto vectorType = shuffleOp.getResultVectorType();
1039 Type llvmType = typeConverter->convertType(vectorType);
1040 ArrayRef<int64_t> mask = shuffleOp.getMask();
1047 int64_t rank = vectorType.getRank();
1049 bool wellFormed0DCase =
1050 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1051 bool wellFormedNDCase =
1052 v1Type.getRank() == rank && v2Type.getRank() == rank;
1053 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1058 if (rank <= 1 && v1Type == v2Type) {
1059 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1060 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1061 llvm::to_vector_of<int32_t>(mask));
1062 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1067 int64_t v1Dim = v1Type.getDimSize(0);
1069 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1070 eltType = arrayType.getElementType();
1072 eltType = cast<VectorType>(llvmType).getElementType();
1073 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1075 for (int64_t extPos : mask) {
1076 Value value = adaptor.getV1();
1077 if (extPos >= v1Dim) {
1079 value = adaptor.getV2();
1081 Value extract =
extractOne(rewriter, *getTypeConverter(), loc, value,
1082 eltType, rank, extPos);
1083 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1084 llvmType, rank, insPos++);
1086 rewriter.replaceOp(shuffleOp, insert);
1091class VectorExtractOpConversion
1094 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1097 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1098 ConversionPatternRewriter &rewriter)
const override {
1099 auto loc = extractOp->getLoc();
1100 auto resultType = extractOp.getResult().getType();
1101 auto llvmResultType = typeConverter->convertType(resultType);
1103 if (!llvmResultType)
1107 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1121 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1125 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1126 extractOp.getSourceVectorType().getRank();
1130 if (extractOp.getSourceVectorType().getRank() == 0) {
1131 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1132 positionVec.push_back(rewriter.getZeroAttr(idxType));
1135 Value extracted = adaptor.getSource();
1136 if (extractsAggregate) {
1137 ArrayRef<OpFoldResult> position(positionVec);
1138 if (extractsScalar) {
1142 position = position.drop_back();
1145 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1148 extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1152 if (extractsScalar) {
1153 extracted = LLVM::ExtractElementOp::create(
1154 rewriter, loc, extracted,
1158 rewriter.replaceOp(extractOp, extracted);
1179 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1182 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1183 ConversionPatternRewriter &rewriter)
const override {
1184 VectorType vType = fmaOp.getVectorType();
1185 if (vType.getRank() > 1)
1188 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1189 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1194class VectorInsertOpConversion
1197 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1200 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1201 ConversionPatternRewriter &rewriter)
const override {
1202 auto loc = insertOp->getLoc();
1203 auto destVectorType = insertOp.getDestVectorType();
1204 auto llvmResultType = typeConverter->convertType(destVectorType);
1206 if (!llvmResultType)
1210 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1232 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1234 bool insertIntoInnermostDim =
1235 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
1237 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1238 positionVec.begin(),
1239 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1240 OpFoldResult positionOfScalarWithin1DVector;
1241 if (destVectorType.getRank() == 0) {
1244 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1245 positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1246 }
else if (insertIntoInnermostDim) {
1247 positionOfScalarWithin1DVector = positionVec.back();
1253 Value sourceAggregate = adaptor.getValueToStore();
1254 if (insertIntoInnermostDim) {
1257 if (isNestedAggregate) {
1260 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1261 llvm::IsaPred<Attribute>)) {
1265 sourceAggregate = LLVM::ExtractValueOp::create(
1266 rewriter, loc, adaptor.getDest(),
1271 sourceAggregate = adaptor.getDest();
1274 sourceAggregate = LLVM::InsertElementOp::create(
1275 rewriter, loc, sourceAggregate.
getType(), sourceAggregate,
1276 adaptor.getValueToStore(),
1280 Value
result = sourceAggregate;
1281 if (isNestedAggregate) {
1282 result = LLVM::InsertValueOp::create(
1283 rewriter, loc, adaptor.getDest(), sourceAggregate,
1287 rewriter.replaceOp(insertOp,
result);
1293struct VectorScalableInsertOpLowering
1295 using ConvertOpToLLVMPattern<
1296 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1299 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1300 ConversionPatternRewriter &rewriter)
const override {
1301 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1302 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1308struct VectorScalableExtractOpLowering
1310 using ConvertOpToLLVMPattern<
1311 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1314 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1315 ConversionPatternRewriter &rewriter)
const override {
1316 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1317 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1318 adaptor.getSource(), adaptor.getPos());
1351 setHasBoundedRewriteRecursion();
/// Rewrites an `FMAOp` on a vector of rank >= 2 into one `FMAOp` per slice
/// along the leading dimension.
///
/// Each iteration extracts slice `i` of lhs, rhs and acc, creates an `FMAOp`
/// on those slices and inserts the result at position `i` of an accumulator
/// vector (`desc`).
// NOTE(review): the statement taken when the rank is below 2, and the
// attribute used to build `zero`, are not visible in this excerpt.
1354 LogicalResult matchAndRewrite(FMAOp op,
1355 PatternRewriter &rewriter)
const override {
1356 auto vType = op.getVectorType();
// Only vectors of rank 2 or higher proceed past this check.
1357 if (vType.getRank() < 2)
1360 auto loc = op.getLoc();
1361 auto elemType = vType.getElementType();
1362 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1364 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
// Walk the leading dimension; `desc` is threaded through every insertion.
1365 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1366 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1367 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1368 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1369 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1370 desc = InsertOp::create(rewriter, loc, fma, desc, i);
/// Computes strides for `memRefType`, or returns `std::nullopt` when they
/// cannot be determined.
///
/// Visible failure cases, each returning `std::nullopt`:
///  - `getStridesAndOffset` fails;
///  - the strides are non-empty and the last one is not 1;
///  - a size or stride inspected in the per-dimension check is dynamic.
// NOTE(review): the identity-layout branch, the loop header around the
// per-dimension check and the condition guarding the final `std::nullopt`
// are not visible in this excerpt.
1379static std::optional<SmallVector<int64_t, 4>>
1380computeContiguousStrides(MemRefType memRefType) {
1383 if (
failed(memRefType.getStridesAndOffset(strides, offset)))
1384 return std::nullopt;
// The innermost dimension must have unit stride.
1385 if (!strides.empty() && strides.back() != 1)
1386 return std::nullopt;
1388 if (memRefType.getLayout().isIdentity())
1395 auto sizes = memRefType.getShape();
// Dynamic sizes or strides at `index` / `index + 1` cannot be reasoned about.
1397 if (ShapedType::isDynamic(sizes[
index + 1]) ||
1398 ShapedType::isDynamic(strides[
index]) ||
1399 ShapedType::isDynamic(strides[
index + 1]))
1400 return std::nullopt;
1402 return std::nullopt;
/// Conversion pattern for `vector::TypeCastOp`.
///
/// Builds a new memref descriptor for the target memref type: the allocated
/// and aligned pointers are copied from the source descriptor, the offset is
/// set to a constant 0, and each size and stride is set to a constant taken
/// from the target shape and from `computeContiguousStrides`.
// NOTE(review): the statements taken on the guard conditions below (static
// shape, descriptor type checks, stride checks) are not visible in this
// excerpt.
1407class VectorTypeCastOpConversion
1410 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1413 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1414 ConversionPatternRewriter &rewriter)
const override {
1415 auto loc = castOp->getLoc();
1416 MemRefType sourceMemRefType =
1417 cast<MemRefType>(castOp.getOperand().getType());
1418 MemRefType targetMemRefType = castOp.getType();
// Both the source and the target memref must have a static shape.
1421 if (!sourceMemRefType.hasStaticShape() ||
1422 !targetMemRefType.hasStaticShape())
// The converted operand is expected to be an LLVM struct (descriptor) type.
1425 auto llvmSourceDescriptorTy =
1426 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1427 if (!llvmSourceDescriptorTy)
1429 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
// `dyn_cast_or_null` because `convertType` may yield a null type.
1431 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1432 typeConverter->convertType(targetMemRefType));
1433 if (!llvmTargetDescriptorTy)
1437 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1440 auto targetStrides = computeContiguousStrides(targetMemRefType);
1444 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1447 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
// Start from a poison descriptor and fill in every field explicitly.
1450 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1452 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1453 desc.setAllocatedPtr(rewriter, loc, allocated);
1456 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1457 desc.setAlignedPtr(rewriter, loc, ptr);
// The offset of the new descriptor is the constant 0.
1459 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1460 auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1461 desc.setOffset(rewriter, loc, zero);
// One size constant and one stride constant per target dimension.
1464 for (
const auto &indexedSize :
1465 llvm::enumerate(targetMemRefType.getShape())) {
1466 int64_t index = indexedSize.index();
1468 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1469 auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1470 desc.setSize(rewriter, loc, index, size);
1471 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1472 (*targetStrides)[index]);
1474 LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1475 desc.setStride(rewriter, loc, index, stride);
1478 rewriter.replaceOp(castOp, {desc});
/// Conversion pattern for `vector::CreateMaskOp`.
///
/// The mask is computed as `step-vector < broadcast(bound)` using a signed
/// less-than comparison. The guard below tests for a destination type that is
/// not a rank-1 scalable vector.
///
/// `enableIndexOpt` selects the index element type: i32 when true, i64
/// otherwise.
// NOTE(review): the statement taken on the rank/scalability guard and the
// code that derives `bound` from operand 0 are not fully visible in this
// excerpt.
1485class VectorCreateMaskOpConversion
1486 :
public OpConversionPattern<vector::CreateMaskOp> {
1488 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1489 bool enableIndexOpt)
1490 : OpConversionPattern<vector::CreateMaskOp>(context),
1491 force32BitVectorIndices(enableIndexOpt) {}
1494 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1495 ConversionPatternRewriter &rewriter)
const override {
1496 auto dstType = op.getType();
1497 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1499 IntegerType idxType =
1500 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1501 auto loc = op->getLoc();
1502 Value
indices = LLVM::StepVectorOp::create(
1507 adaptor.getOperands()[0]);
// Broadcast the bound to the type of the step vector, then compare.
1508 Value bounds = BroadcastOp::create(rewriter, loc,
indices.getType(), bound);
1509 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1511 rewriter.replaceOp(op, comp);
// Set once at construction from `enableIndexOpt`.
1516 const bool force32BitVectorIndices;
1520 SymbolTableCollection *symbolTables =
nullptr;
/// Constructs the pattern from `typeConverter`.
///
/// `symbolTables` is optional (defaults to `nullptr`) and is stored in the
/// member of the same name.
1523 explicit VectorPrintOpConversion(
1524 const LLVMTypeConverter &typeConverter,
1525 SymbolTableCollection *symbolTables =
nullptr)
1526 : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1527 symbolTables(symbolTables) {}
/// Lowers a `vector::PrintOp`.
///
/// Visible structure:
///  - if the op has a source value, vector-typed values take a separate
///    branch and the value is passed to `emitScalarPrint`;
///  - if the op carries a string literal, a helper call whose result is
///    checked through `createResult.failed()` is made with that literal;
///  - otherwise, if the punctuation is not `NoPunctuation`, a function is
///    selected per punctuation kind (Close, Open, Comma, NewLine) and called
///    through `emitCall`.
// NOTE(review): the bodies of the individual branches and the failure
// handling are not visible in this excerpt.
1543 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1544 ConversionPatternRewriter &rewriter)
const override {
1545 auto parent =
printOp->getParentOfType<ModuleOp>();
// The source operand is optional; printing a value happens only if present.
1551 if (
auto value = adaptor.getSource()) {
1553 if (isa<VectorType>(printType)) {
1557 if (
failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1561 auto punct =
printOp.getPunctuation();
1562 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1565 *stringLiteral, *getTypeConverter(),
1567 if (createResult.failed())
1570 }
else if (punct != PrintPunctuation::NoPunctuation) {
// Immediately-invoked lambda: maps the punctuation kind to a function.
1571 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1573 case PrintPunctuation::Close:
1576 case PrintPunctuation::Open:
1579 case PrintPunctuation::Comma:
1582 case PrintPunctuation::NewLine:
1586 llvm_unreachable(
"unexpected punctuation");
1591 emitCall(rewriter,
printOp->getLoc(), op.value());
1599 enum class PrintConversion {
/// Emits the call that prints a single scalar `value` of type `printType`.
///
/// A `PrintConversion` is chosen from the type and applied to `value` before
/// the final `emitCall`:
///  - `ZeroExt64`: `arith::ExtUIOp` to a 64-bit integer;
///  - `SignExt64`: `arith::ExtSIOp` to a 64-bit integer;
///  - `Bitcast16`: `LLVM::BitcastOp` to a 16-bit integer;
///  - `None`: value passed through.
/// For the generic `FloatType` branch, the float semantics enum is
/// materialized as an i32 constant and a separate `emitCall` is made.
// NOTE(review): the conditions selecting each conversion and printer
// function, and the handling of an unconvertible `printType`, are only
// partly visible in this excerpt.
1608 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1609 ModuleOp parent, Location loc, Type printType,
1610 Value value)
const {
// A null converted type means `printType` is not supported by the converter.
1611 if (typeConverter->convertType(printType) ==
nullptr)
1615 PrintConversion conversion = PrintConversion::None;
1616 FailureOr<Operation *> printer;
1622 conversion = PrintConversion::Bitcast16;
1625 conversion = PrintConversion::Bitcast16;
1629 }
else if (
auto intTy = dyn_cast<IntegerType>(printType)) {
1633 unsigned width = intTy.getWidth();
1634 if (intTy.isUnsigned()) {
1637 conversion = PrintConversion::ZeroExt64;
1644 assert(intTy.isSignless() || intTy.isSigned());
1649 conversion = PrintConversion::ZeroExt64;
1650 else if (width < 64)
1651 conversion = PrintConversion::SignExt64;
1658 }
else if (
auto floatTy = dyn_cast<FloatType>(printType)) {
1661 llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1662 Value semValue = LLVM::ConstantOp::create(
1663 rewriter, loc, rewriter.getI32Type(),
1664 rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1666 LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1669 emitCall(rewriter, loc, printer.value(),
// Apply the conversion selected above before calling the printer.
1678 switch (conversion) {
1679 case PrintConversion::ZeroExt64:
1680 value = arith::ExtUIOp::create(
1681 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1683 case PrintConversion::SignExt64:
1684 value = arith::ExtSIOp::create(
1685 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1687 case PrintConversion::Bitcast16:
1688 value = LLVM::BitcastOp::create(
1689 rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
1691 case PrintConversion::None:
1694 emitCall(rewriter, loc, printer.value(), value);
/// Creates an `LLVM::CallOp` at `loc` with no result types (empty
/// `TypeRange`), using a symbol reference built from `ref` as the callee.
// NOTE(review): the remaining parameters and call arguments are not visible
// in this excerpt.
1699 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1701 LLVM::CallOp::create(rewriter, loc,
TypeRange(), SymbolRefAttr::get(ref),
/// Conversion pattern for `vector::BroadcastOp` from a non-vector source to a
/// result of rank 0 or 1.
///
/// Vector-typed sources and results of rank > 1 are rejected with a match
/// failure. For rank 0 the op becomes a single `LLVM::InsertElementOp` that
/// writes the source at index 0 of a poison vector. Otherwise the same
/// insertion is followed by an `LLVM::ShuffleVectorOp` whose mask is `width`
/// zeros.
// NOTE(review): the definitions of `resultType`, `vectorType`, `poison` and
// `width` are not visible in this excerpt.
1709struct VectorBroadcastScalarToLowRankLowering
1711 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1714 matchAndRewrite(vector::BroadcastOp
broadcast, OpAdaptor adaptor,
1715 ConversionPatternRewriter &rewriter)
const override {
1716 if (isa<VectorType>(
broadcast.getSourceType()))
1717 return rewriter.notifyMatchFailure(
1718 broadcast,
"broadcast from vector type not handled");
1721 if (resultType.getRank() > 1)
1722 return rewriter.notifyMatchFailure(
broadcast,
1723 "broadcast to 2+-d handled elsewhere");
// i32 constant 0, used as the insertion index.
1729 auto zero = LLVM::ConstantOp::create(
1731 typeConverter->convertType(rewriter.getIntegerType(32)),
1732 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
// Rank 0: the single insertion is already the final result.
1735 if (resultType.getRank() == 0) {
1736 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1737 broadcast, vectorType, poison, adaptor.getSource(), zero);
1742 LLVM::InsertElementOp::create(rewriter,
broadcast.
getLoc(), vectorType,
1743 poison, adaptor.getSource(), zero);
// An all-zero shuffle mask selects element 0 for every lane.
1747 SmallVector<int32_t> zeroValues(width, 0);
1750 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
/// Conversion pattern for `BroadcastOp` from a non-vector source to a result
/// of rank >= 2.
///
/// Vector-typed sources and results of rank <= 1 are rejected with a match
/// failure. A 1-D vector is built by inserting the source at index 0 of a
/// poison vector and shuffling with an all-zero mask as long as the innermost
/// result dimension. That 1-D vector is then inserted at every position
/// visited by `nDVectorIterate` into an n-D aggregate that starts as poison.
// NOTE(review): the definition of `resultType`, the call producing
// `vectorTypeInfo` and the statement taken when either LLVM type is null are
// not visible in this excerpt.
1761struct VectorBroadcastScalarToNdLowering
1763 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1766 matchAndRewrite(BroadcastOp
broadcast, OpAdaptor adaptor,
1767 ConversionPatternRewriter &rewriter)
const override {
1768 if (isa<VectorType>(
broadcast.getSourceType()))
1769 return rewriter.notifyMatchFailure(
1770 broadcast,
"broadcast from vector type not handled");
1773 if (resultType.getRank() <= 1)
1774 return rewriter.notifyMatchFailure(
1775 broadcast,
"broadcast to 1-d or 0-d handled elsewhere");
1779 auto vectorTypeInfo =
1781 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1782 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1783 if (!llvmNDVectorTy || !llvm1DVectorTy)
// `desc` is the n-D result aggregate; `vdesc` seeds the 1-D vector.
1787 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1791 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1792 auto zero = LLVM::ConstantOp::create(
1793 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1794 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1795 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1796 vdesc, adaptor.getSource(), zero);
// An all-zero shuffle mask selects element 0 for every lane.
1799 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1800 SmallVector<int32_t> zeroValues(width, 0);
1801 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
// Insert the same 1-D vector at every visited position of the aggregate.
1805 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1806 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
/// Conversion pattern for `vector::InterleaveOp` with a rank-1 result.
///
/// Results of any other rank are rejected with a match failure. Scalable
/// results are lowered to `LLVM::vector_interleave2`. Otherwise an
/// `LLVM::ShuffleVectorOp` of lhs and rhs is created with the mask
/// `[0, n/2, 1, n/2 + 1, ...]`, where `n` is the number of result elements.
1815struct VectorInterleaveOpLowering
1820 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1821 ConversionPatternRewriter &rewriter)
const override {
1822 VectorType resultType = interleaveOp.getResultVectorType();
1824 if (resultType.getRank() != 1)
1825 return rewriter.notifyMatchFailure(interleaveOp,
1826 "InterleaveOp not rank 1");
1828 if (resultType.isScalable()) {
1829 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1830 interleaveOp, typeConverter->convertType(resultType),
1831 adaptor.getLhs(), adaptor.getRhs());
// Build the shuffle mask: each step adds index i, then index n/2 + i.
1838 int64_t resultVectorSize = resultType.getNumElements();
1839 SmallVector<int32_t> interleaveShuffleMask;
1840 interleaveShuffleMask.reserve(resultVectorSize);
1841 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1842 interleaveShuffleMask.push_back(i);
1843 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1845 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1846 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1847 interleaveShuffleMask);
/// Conversion pattern for `vector::DeinterleaveOp` with rank-1 results.
///
/// Results of any other rank are rejected with a match failure. Scalable
/// results: the op's result types are packed with `packOperationResults`, an
/// `LLVM::vector_deinterleave2` is created, and its result 0 is unpacked with
/// two `LLVM::ExtractValueOp`s (positions 0 and 1). Otherwise two
/// `LLVM::ShuffleVectorOp`s of the source with a poison vector are created,
/// one with the "even" mask and one with the "odd" mask.
// NOTE(review): the conditions deciding which mask receives index `i` are
// not visible in this excerpt.
1854struct VectorDeinterleaveOpLowering
1859 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1860 ConversionPatternRewriter &rewriter)
const override {
1861 VectorType resultType = deinterleaveOp.getResultVectorType();
1862 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1863 auto loc = deinterleaveOp.getLoc();
1867 if (resultType.getRank() != 1)
1868 return rewriter.notifyMatchFailure(deinterleaveOp,
1869 "DeinterleaveOp not rank 1");
1871 if (resultType.isScalable()) {
1872 const auto *llvmTypeConverter = this->getTypeConverter();
1873 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1874 auto packedOpResults =
1875 llvmTypeConverter->packOperationResults(deinterleaveResults);
1876 auto intrinsic = LLVM::vector_deinterleave2::create(
1877 rewriter, loc, packedOpResults, adaptor.getSource());
// Unpack the two members of the intrinsic's packed result.
1879 auto evenResult = LLVM::ExtractValueOp::create(
1880 rewriter, loc, intrinsic->getResult(0), 0);
1881 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1882 intrinsic->getResult(0), 1);
1884 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenResult, oddResult});
// Each mask is reserved to the size of one result vector.
1891 int64_t resultVectorSize = resultType.getNumElements();
1892 SmallVector<int32_t> evenShuffleMask;
1893 SmallVector<int32_t> oddShuffleMask;
1895 evenShuffleMask.reserve(resultVectorSize);
1896 oddShuffleMask.reserve(resultVectorSize);
1898 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1900 evenShuffleMask.push_back(i);
1902 oddShuffleMask.push_back(i);
// The second shuffle operand is a poison vector of the source type.
1905 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1906 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1907 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1908 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1909 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1911 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenShuffle, oddShuffle});
/// Conversion pattern for `vector::FromElementsOp` on vectors of rank <= 1.
///
/// Vectors of rank > 1 are rejected with a match failure. The result starts
/// as a poison value of the converted vector type. For each element operand a
/// constant of the converted index type is created from its position and an
/// `LLVM::InsertElementOp` is chained onto the running result.
1917struct VectorFromElementsLowering
1922 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1923 ConversionPatternRewriter &rewriter)
const override {
1924 Location loc = fromElementsOp.getLoc();
1925 VectorType vectorType = fromElementsOp.getType();
1929 if (vectorType.getRank() > 1)
1930 return rewriter.notifyMatchFailure(fromElementsOp,
1931 "rank > 1 vectors are not supported");
1932 Type llvmType = typeConverter->convertType(vectorType);
1933 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1934 Value
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
// `result` is threaded through the insertions, one per element operand.
1935 for (
auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1937 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1938 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType,
result,
1941 rewriter.replaceOp(fromElementsOp,
result);
/// Conversion pattern for `vector::ToElementsOp`.
///
/// For each result of the op, a constant index is created and an
/// `LLVM::ExtractElementOp` produces the element; the op is then replaced by
/// the collected `results`. Results with no uses are tested for separately.
// NOTE(review): the statement taken for unused results and the store into
// `results` are not visible in this excerpt.
1947struct VectorToElementsLowering
1952 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1953 ConversionPatternRewriter &rewriter)
const override {
1954 Location loc = toElementsOp.getLoc();
1955 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1956 Value source = adaptor.getSource();
// One slot per op result.
1958 SmallVector<Value> results(toElementsOp->getNumResults());
1959 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1961 if (element.use_empty())
1964 auto constIdx = LLVM::ConstantOp::create(
1965 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1966 auto llvmType = typeConverter->convertType(element.getType());
1968 Value
result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1973 rewriter.replaceOp(toElementsOp, results);
/// Conversion pattern for `vector::StepOp`.
///
/// The op is replaced by an `LLVM::StepVectorOp` of the converted result
/// type. Non-scalable result types take a separate branch.
// NOTE(review): the statement taken for non-scalable result types is not
// visible in this excerpt.
1979struct VectorScalableStepOpLowering
1984 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1985 ConversionPatternRewriter &rewriter)
const override {
1986 auto resultType = cast<VectorType>(stepOp.getType());
1987 if (!resultType.isScalable()) {
1990 Type llvmType = typeConverter->convertType(stepOp.getType());
1991 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
/// Rewrite pattern for `vector::ContractionOp`, built on
/// `MaskableOpRewritePattern`.
///
/// The pattern benefit defaults to 100. The rewrite itself is implemented
/// out of line in `matchAndRewriteMaskableOp`.
2006class ContractionOpToMatmulOpLowering
2009 using MaskableOpRewritePattern::MaskableOpRewritePattern;
2011 ContractionOpToMatmulOpLowering(MLIRContext *context,
2012 PatternBenefit benefit = 100)
2013 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2016 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
2017 PatternRewriter &rewriter)
const override;
/// Rewrites a `vector::ContractionOp` as an `LLVM::MatrixMultiplyOp` on
/// flattened operands.
///
/// Visible steps:
///  1. Guards on a scalable result vector and on the lhs element type
///     differing from the result element type.
///  2. lhs and rhs may each be transposed with permutation `{1, 0}`.
///  3. Both operands are flattened to 1-D with `vector::ShapeCastOp`.
///  4. `LLVM::MatrixMultiplyOp` is created with `lhsRows`, `lhsColumns` and
///     `rhsColumns`; its result type has `lhsRows * rhsColumns` elements.
///  5. The product is shape-cast using `{lhsRows, rhsColumns}` and may be
///     transposed; an unreachable marks invalid contraction semantics.
///  6. The accumulator is added with `arith::AddIOp` for integer element
///     types and `arith::AddFOp` otherwise.
// NOTE(review): the statements taken on the guards, the indexing-map
// conditions that trigger each transpose, and the final return are not
// visible in this excerpt.
2037FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2038 vector::ContractionOp op, MaskingOpInterface maskOp,
2044 auto iteratorTypes = op.getIteratorTypes().getValue();
2050 Type opResType = op.getType();
2051 VectorType vecType = dyn_cast<VectorType>(opResType);
2052 if (vecType && vecType.isScalable()) {
2057 Type elementType = op.getLhsType().getElementType();
// The result may be a vector or a scalar; compare element types either way.
2061 Type dstElementType = vecType ? vecType.getElementType() : opResType;
2062 if (elementType != dstElementType)
2067 MLIRContext *ctx = op.getContext();
2068 Location loc = op.getLoc();
2072 Value
lhs = op.getLhs();
2073 auto lhsMap = op.getIndexingMapsArray()[0];
2075 lhs = vector::TransposeOp::create(rew, loc,
lhs, ArrayRef<int64_t>{1, 0});
2080 Value
rhs = op.getRhs();
2081 auto rhsMap = op.getIndexingMapsArray()[1];
2083 rhs = vector::TransposeOp::create(rew, loc,
rhs, ArrayRef<int64_t>{1, 0});
// Matrix dimensions are read after any transposition above.
2088 VectorType lhsType = cast<VectorType>(
lhs.getType());
2089 VectorType rhsType = cast<VectorType>(
rhs.getType());
2090 int64_t lhsRows = lhsType.getDimSize(0);
2091 int64_t lhsColumns = lhsType.getDimSize(1);
2092 int64_t rhsColumns = rhsType.getDimSize(1);
// Flatten both operands to 1-D vectors.
2094 Type flattenedLHSType =
2095 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2096 lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType,
lhs);
2098 Type flattenedRHSType =
2099 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2100 rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType,
rhs);
2102 Value
mul = LLVM::MatrixMultiplyOp::create(
2104 VectorType::get(lhsRows * rhsColumns,
2105 cast<VectorType>(
lhs.getType()).getElementType()),
2106 lhs,
rhs, lhsRows, lhsColumns, rhsColumns);
2108 mul = vector::ShapeCastOp::create(
2110 VectorType::get({lhsRows, rhsColumns},
2115 auto accMap = op.getIndexingMapsArray()[2];
2117 mul = vector::TransposeOp::create(rew, loc,
mul, ArrayRef<int64_t>{1, 0});
2119 llvm_unreachable(
"invalid contraction semantics");
// Integer element types use AddIOp; everything else uses AddFOp.
2121 Value res = isa<IntegerType>(elementType)
2122 ?
static_cast<Value
>(
2123 arith::AddIOp::create(rew, loc, op.getAcc(),
mul))
2124 : static_cast<Value>(
2125 arith::AddFOp::create(rew, loc, op.getAcc(),
mul));
/// Rewrite pattern for `vector::TransposeOp` that produces an
/// `LLVM::MatrixTransposeOp` on a flattened vector.
///
/// Scalable input vectors are tested for, with the message "This lowering
/// does not support scalable vectors". A second guard tests for anything
/// other than a rank-2 result with permutation `[1, 0]`. The input is
/// shape-cast to a 1-D vector holding all result elements and passed to
/// `LLVM::MatrixTransposeOp` together with `rows` and `columns`.
// NOTE(review): the statements taken on the two guards, the definitions of
// `matrix`, `rows` and `columns`, and the final replacement of `op` are not
// visible in this excerpt.
2143class TransposeOpToMatrixTransposeOpLowering
2144 :
public OpRewritePattern<vector::TransposeOp> {
2148 LogicalResult matchAndRewrite(vector::TransposeOp op,
2149 PatternRewriter &rewriter)
const override {
2150 auto loc = op.getLoc();
2152 Value input = op.getVector();
2153 VectorType inputType = op.getSourceVectorType();
2154 VectorType resType = op.getResultVectorType();
2156 if (inputType.isScalable())
2158 op,
"This lowering does not support scalable vectors");
2161 ArrayRef<int64_t> transp = op.getPermutation();
2163 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
// Flatten to 1-D for the matrix transpose op.
2167 Type flattenedType =
2168 VectorType::get(resType.getNumElements(), resType.getElementType());
2170 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2173 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2174 matrix, rows, columns);
2189 patterns.add<ContractionOpToMatmulOpLowering>(
patterns.getContext(), benefit);
2194 patterns.add<TransposeOpToMatrixTransposeOpLowering>(
patterns.getContext(),
2201 bool reassociateFPReductions,
bool force32BitVectorIndices,
2202 bool useVectorAlignment) {
2205 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2206 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2207 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2208 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2209 VectorLoadStoreConversion<vector::StoreOp>,
2210 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2211 VectorGatherOpConversion, VectorScatterOpConversion>(
2212 converter, useVectorAlignment);
2213 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2214 VectorExtractOpConversion, VectorFMAOp1DConversion,
2215 VectorInsertOpConversion, VectorPrintOpConversion,
2216 VectorTypeCastOpConversion, VectorScaleOpConversion,
2217 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2218 VectorBroadcastScalarToLowRankLowering,
2219 VectorBroadcastScalarToNdLowering,
2220 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2221 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2222 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2223 VectorToElementsLowering, VectorScalableStepOpLowering>(
2230 void loadDependentDialects(
MLIRContext *context)
const final {
2231 context->loadDialect<LLVM::LLVMDialect>();
2236 void populateConvertToLLVMConversionPatterns(
2237 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2238 RewritePatternSet &
patterns)
const final {
2247 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.