31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
45 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
47 auto idxType = rewriter.getIndexType();
48 auto constant = LLVM::ConstantOp::create(
49 rewriter, loc, typeConverter.convertType(idxType),
50 rewriter.getIntegerAttr(idxType, pos));
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
62 auto idxType = rewriter.getIndexType();
63 auto constant = LLVM::ConstantOp::create(
64 rewriter, loc, typeConverter.convertType(idxType),
65 rewriter.getIntegerAttr(idxType, pos));
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
74 VectorType vectorType,
unsigned &align) {
75 Type convertedVectorTy = typeConverter.convertType(vectorType);
76 if (!convertedVectorTy)
79 llvm::LLVMContext llvmContext;
89 MemRefType memrefType,
unsigned &align) {
90 Type elementTy = typeConverter.convertType(memrefType.getElementType());
96 llvm::LLVMContext llvmContext;
109 VectorType vectorType,
110 MemRefType memrefType,
unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
127 if (!memRefType.isLastDimUnitStride())
137 MemRefType memRefType,
Value llvmMemref,
Value base,
140 "unsupported memref type");
141 assert(vectorType.getRank() == 1 &&
"expected a 1-d vector type");
145 vectorType.getScalableDims()[0]);
146 return LLVM::GEPOp::create(
147 rewriter, loc, ptrsType,
148 typeConverter.convertType(memRefType.getElementType()), base,
index);
155 if (
auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
160 return cast<Value>(foldResult);
166using VectorScaleOpConversion =
// Conversion pattern for vector::BitCastOp. Only part of the class is visible
// in this excerpt; the comments describe the visible lines only.
170class VectorBitCastOpConversion
173 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
176 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter)
const override {
// The rank of the result vector type is tested against 1; the statement taken
// when the rank is larger is not visible in this excerpt.
179 VectorType resultTy = bitCastOp.getResultVectorType();
180 if (resultTy.getRank() > 1)
// Convert the result type and replace the op with an LLVM::BitcastOp applied
// to the first converted operand.
182 Type newResultTy = typeConverter->convertType(resultTy);
183 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
184 adaptor.getOperands()[0]);
// replaceLoadOrStoreOp overload for vector::LoadOp.
// Replaces `loadOp` with an LLVM::LoadOp. The visible arguments are the vector
// type, the pointer, the alignment and the op's nontemporal flag; one argument
// line of the call is not visible in this excerpt.
192static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193 vector::LoadOpAdaptor adaptor,
194 VectorType vectorTy,
Value ptr,
unsigned align,
195 ConversionPatternRewriter &rewriter) {
196 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy,
ptr, align,
198 loadOp.getNontemporal());
// replaceLoadOrStoreOp overload for vector::MaskedLoadOp.
// Replaces `loadOp` with an LLVM::MaskedLoadOp built from the vector type, the
// pointer, the adaptor's mask and pass-through operands, and the alignment.
201static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202 vector::MaskedLoadOpAdaptor adaptor,
203 VectorType vectorTy,
Value ptr,
unsigned align,
204 ConversionPatternRewriter &rewriter) {
205 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206 loadOp, vectorTy,
ptr, adaptor.getMask(), adaptor.getPassThru(), align);
// replaceLoadOrStoreOp overload for vector::StoreOp.
// Replaces `storeOp` with an LLVM::StoreOp whose stored value comes from the
// adaptor and whose last visible argument is the op's nontemporal flag; the
// argument line in between is not visible in this excerpt.
209static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210 vector::StoreOpAdaptor adaptor,
211 VectorType vectorTy,
Value ptr,
unsigned align,
212 ConversionPatternRewriter &rewriter) {
213 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
215 storeOp.getNontemporal());
// replaceLoadOrStoreOp overload for vector::MaskedStoreOp.
// Replaces `storeOp` with an LLVM::MaskedStoreOp built from the adaptor's value
// to store, the pointer, the adaptor's mask and the alignment.
// `vectorTy` is accepted but not used on the visible lines.
218static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219 vector::MaskedStoreOpAdaptor adaptor,
220 VectorType vectorTy,
Value ptr,
unsigned align,
221 ConversionPatternRewriter &rewriter) {
222 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223 storeOp, adaptor.getValueToStore(),
ptr, adaptor.getMask(), align);
// Conversion pattern templated on the load/store op kind (`LoadOrStoreOp`).
// Only part of the class is visible in this excerpt.
228template <
class LoadOrStoreOp>
// The constructor forwards the type converter to the base pattern and records
// `useVectorAlign` in the `useVectorAlignment` member.
231 explicit VectorLoadStoreConversion(
const LLVMTypeConverter &typeConv,
233 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
234 useVectorAlignment(useVectorAlign) {}
235 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter)
const override {
// The vector type's rank is tested against 1; the statement taken for larger
// ranks is not visible in this excerpt.
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
// Start from the op's own alignment, or 0 when it has none. `align` is then
// passed, with `useVectorAlignment`, to a helper whose name is not visible;
// if that helper fails the match is rejected.
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
254 memRefTy, align, useVectorAlignment)))
255 return rewriter.notifyMatchFailure(loadOrStoreOp,
256 "could not resolve alignment");
// Convert the vector type, then hand off to the matching replaceLoadOrStoreOp
// overload together with the data pointer and the alignment.
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
// Flag set from the constructor's `useVectorAlign` argument and forwarded to
// the alignment helper above.
273 const bool useVectorAlignment;
// Conversion pattern for vector::GatherOp. Only part of the class is visible
// in this excerpt.
277class VectorGatherOpConversion
// The constructor forwards the type converter to the base pattern and records
// `useVectorAlign` in the `useVectorAlignment` member.
280 explicit VectorGatherOpConversion(
const LLVMTypeConverter &typeConv,
282 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
283 useVectorAlignment(useVectorAlign) {}
284 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 Location loc = gather->getLoc();
// The base must already be a memref; this is asserted rather than reported
// as a match failure.
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType &&
"The base should be bufferized");
// The condition guarding this failure is not visible in this excerpt.
294 return rewriter.notifyMatchFailure(gather,
"memref type not supported");
// Vectors of rank greater than 1 are rejected.
296 VectorType vType = gather.getVectorType();
297 if (vType.getRank() > 1) {
298 return rewriter.notifyMatchFailure(
299 gather,
"only 1-D vectors can be lowered to LLVM");
// Start from the op's own alignment, or 0 when it has none; the match is
// rejected if the alignment helper (name not visible) fails.
304 unsigned align = gather.getAlignment().value_or(0);
307 memRefType, align, useVectorAlignment)))
308 return rewriter.notifyMatchFailure(gather,
"could not resolve alignment");
312 adaptor.getBase(), adaptor.getOffsets());
313 Value base = adaptor.getBase();
// Compute the per-element pointers from the base, `ptr` and the index vector.
315 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
316 base, ptr, adaptor.getIndices(), vType);
// Replace the op with LLVM::masked_gather; the alignment is passed as an
// I32 integer attribute.
319 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
320 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
321 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
// Flag set from the constructor's `useVectorAlign` argument.
330 const bool useVectorAlignment;
// Conversion pattern for vector::ScatterOp. Only part of the class is visible
// in this excerpt.
334class VectorScatterOpConversion
// The constructor forwards the type converter to the base pattern and records
// `useVectorAlign` in the `useVectorAlignment` member.
337 explicit VectorScatterOpConversion(
const LLVMTypeConverter &typeConv,
339 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
340 useVectorAlignment(useVectorAlign) {}
342 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
345 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter)
const override {
347 auto loc = scatter->getLoc();
// The base must already be a memref; this is asserted rather than reported
// as a match failure.
348 auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
349 assert(memRefType &&
"The base should be bufferized");
// The condition guarding this failure is not visible in this excerpt.
352 return rewriter.notifyMatchFailure(scatter,
"memref type not supported");
// Vectors of rank greater than 1 are rejected.
354 VectorType vType = scatter.getVectorType();
355 if (vType.getRank() > 1) {
356 return rewriter.notifyMatchFailure(
357 scatter,
"only 1-D vectors can be lowered to LLVM");
// Start from the op's own alignment, or 0 when it has none; the match is
// rejected if the alignment helper (name not visible) fails.
362 unsigned align = scatter.getAlignment().value_or(0);
365 memRefType, align, useVectorAlignment)))
366 return rewriter.notifyMatchFailure(scatter,
367 "could not resolve alignment");
371 adaptor.getBase(), adaptor.getOffsets());
// Compute the per-element pointers from the base, `ptr` and the index vector.
373 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
374 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
// Replace the op with LLVM::masked_scatter; the alignment is passed as an
// I32 integer attribute.
377 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
378 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
379 rewriter.getI32IntegerAttr(align));
// Flag set from the constructor's `useVectorAlign` argument.
388 const bool useVectorAlignment;
// Conversion pattern for vector::ExpandLoadOp. Only part of the class is
// visible in this excerpt.
392class VectorExpandLoadOpConversion
395 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
398 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter)
const override {
400 auto loc = expand->getLoc();
401 MemRefType memRefType = expand.getMemRefType();
404 auto vtype = typeConverter->convertType(expand.getVectorType());
406 adaptor.getBase(), adaptor.getIndices());
// Unlike the load/store/gather/scatter patterns, the fallback alignment here
// is 1 rather than 0.
411 uint64_t alignment = expand.getAlignment().value_or(1);
// Replace the op with LLVM::masked_expandload; the trailing arguments of the
// call are not visible in this excerpt.
413 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
414 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
// Conversion pattern for vector::CompressStoreOp. Only part of the class is
// visible in this excerpt.
421class VectorCompressStoreOpConversion
424 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
427 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const override {
429 auto loc = compress->getLoc();
430 MemRefType memRefType = compress.getMemRefType();
434 adaptor.getBase(), adaptor.getIndices());
// The op's alignment is used when present; otherwise it falls back to 1.
439 uint64_t alignment = compress.getAlignment().value_or(1);
// Replace the op with LLVM::masked_compressstore of the value to store, the
// pointer, the mask and the alignment.
441 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
442 compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
// Empty tag types. Each one selects a createReductionNeutralValue overload, so
// the kind of neutral constant is chosen at compile time.
// Neutral value built from the type's zero attribute.
448class ReductionNeutralZero {};
// Neutral value: integer 1.
449class ReductionNeutralIntOne {};
// Neutral value: floating-point 1.0.
450class ReductionNeutralFPOne {};
// Tag for the "all ones" overload (its value expression is only partly visible).
451class ReductionNeutralAllOnes {};
// Tag for the overload built from APInt::getSignedMinValue.
452class ReductionNeutralSIntMin {};
// Tag for the overload built from APInt::getMinValue.
453class ReductionNeutralUIntMin {};
// Tag for the overload built from APInt::getSignedMaxValue.
454class ReductionNeutralSIntMax {};
// Tag for the overload built from APInt::getMaxValue.
455class ReductionNeutralUIntMax {};
// Tags for the two floating-point overloads built from APFloat::getQNaN.
456class ReductionNeutralFPMin {};
457class ReductionNeutralFPMax {};
// ReductionNeutralZero overload: returns an LLVM::ConstantOp of `llvmType`
// holding that type's zero attribute. The tag parameter is used only for
// overload selection.
460static Value createReductionNeutralValue(ReductionNeutralZero neutral,
461 ConversionPatternRewriter &rewriter,
463 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
464 rewriter.getZeroAttr(llvmType));
// ReductionNeutralIntOne overload: returns an LLVM::ConstantOp of `llvmType`
// holding the integer attribute 1.
468static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
469 ConversionPatternRewriter &rewriter,
471 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
472 rewriter.getIntegerAttr(llvmType, 1));
// ReductionNeutralFPOne overload: returns an LLVM::ConstantOp of `llvmType`
// holding the float attribute 1.0.
476static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
477 ConversionPatternRewriter &rewriter,
479 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
480 rewriter.getFloatAttr(llvmType, 1.0));
// ReductionNeutralAllOnes overload: returns an LLVM::ConstantOp of `llvmType`
// holding an integer attribute. The attribute's arguments are not visible in
// this excerpt; presumably an all-ones bit pattern, going by the tag name.
484static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
485 ConversionPatternRewriter &rewriter,
487 return LLVM::ConstantOp::create(
488 rewriter, loc, llvmType,
489 rewriter.getIntegerAttr(
// ReductionNeutralSIntMin overload: returns an LLVM::ConstantOp of `llvmType`
// whose integer attribute comes from llvm::APInt::getSignedMinValue. The
// arguments of that call are not visible in this excerpt.
494static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
495 ConversionPatternRewriter &rewriter,
497 return LLVM::ConstantOp::create(
498 rewriter, loc, llvmType,
499 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
// ReductionNeutralUIntMin overload: returns an LLVM::ConstantOp of `llvmType`
// whose integer attribute comes from llvm::APInt::getMinValue. The arguments
// of that call are not visible in this excerpt.
504static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
505 ConversionPatternRewriter &rewriter,
507 return LLVM::ConstantOp::create(
508 rewriter, loc, llvmType,
509 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
// ReductionNeutralSIntMax overload: returns an LLVM::ConstantOp of `llvmType`
// whose integer attribute comes from llvm::APInt::getSignedMaxValue. The
// arguments of that call are not visible in this excerpt.
514static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
515 ConversionPatternRewriter &rewriter,
517 return LLVM::ConstantOp::create(
518 rewriter, loc, llvmType,
519 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
// ReductionNeutralUIntMax overload: returns an LLVM::ConstantOp of `llvmType`
// whose integer attribute comes from llvm::APInt::getMaxValue. The arguments
// of that call are not visible in this excerpt.
524static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
525 ConversionPatternRewriter &rewriter,
527 return LLVM::ConstantOp::create(
528 rewriter, loc, llvmType,
529 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
// ReductionNeutralFPMin overload: casts `llvmType` to FloatType and returns an
// LLVM::ConstantOp whose float attribute is a quiet NaN in that type's
// semantics.
// NOTE(review): the remaining getQNaN arguments (presumably the sign) are on
// a line not visible here; confirm how this differs from the FPMax overload.
534static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
535 ConversionPatternRewriter &rewriter,
537 auto floatType = cast<FloatType>(llvmType);
538 return LLVM::ConstantOp::create(
539 rewriter, loc, llvmType,
540 rewriter.getFloatAttr(
541 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
// ReductionNeutralFPMax overload: casts `llvmType` to FloatType and returns an
// LLVM::ConstantOp whose float attribute is a quiet NaN in that type's
// semantics.
// NOTE(review): the remaining getQNaN arguments (presumably the sign) are on
// a line not visible here; confirm how this differs from the FPMin overload.
546static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
547 ConversionPatternRewriter &rewriter,
549 auto floatType = cast<FloatType>(llvmType);
550 return LLVM::ConstantOp::create(
551 rewriter, loc, llvmType,
552 rewriter.getFloatAttr(
553 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
// Templated on a ReductionNeutral tag type. The visible return statement
// builds a neutral value by default-constructing the tag and dispatching to
// createReductionNeutralValue. The condition under which an existing
// accumulator is returned instead is not visible in this excerpt.
559template <
class ReductionNeutral>
560static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
566 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
// Builds an i32 value holding the length of a 1-D vector type.
// `llvmType` is cast to VectorType and asserted to have exactly one dimension.
573static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
575 VectorType vType = cast<VectorType>(llvmType);
576 auto vShape = vType.getShape();
577 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
// Constant holding the static size of the single dimension.
579 Value baseVecLength = LLVM::ConstantOp::create(
580 rewriter, loc, rewriter.getI32Type(),
581 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
// Fixed-length vectors: the static size is the answer.
583 if (!vType.getScalableDims()[0])
584 return baseVecLength;
// Scalable vectors: multiply the static size by the runtime vector scale.
// The index cast below targets i32; the line that receives its result is not
// visible in this excerpt.
587 Value vScale = vector::VectorScaleOp::create(rewriter, loc);
589 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
590 Value scalableVecLength =
591 arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
592 return scalableVecLength;
// Lowers an integer reduction using the reduction intrinsic `LLVMRedIntrinOp`
// and the scalar combining op `ScalarOp`.
// The intrinsic is created over `vectorOperand`; `ScalarOp` then combines
// `accumulator` with that result. The condition guarding the combine step
// and the return are not visible in this excerpt.
599template <
class LLVMRedIntrinOp,
class ScalarOp>
600static Value createIntegerReductionArithmeticOpLowering(
601 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
605 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
608 result = ScalarOp::create(rewriter, loc, accumulator,
result);
// Lowers an integer min/max style reduction using `LLVMRedIntrinOp`.
// The intrinsic is created over `vectorOperand`. The accumulator is then
// compared with that result using `predicate`, and the select keeps the
// accumulator when the comparison holds and the reduced value otherwise.
// The condition guarding the compare/select step is not visible here.
616template <
class LLVMRedIntrinOp>
617static Value createIntegerReductionComparisonOpLowering(
618 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
619 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
621 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
624 LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator,
result);
625 result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator,
result);
// Compile-time map from a vector reduction intrinsic to the scalar op that
// computes the same operation; the mapped op is exposed as the nested `Type`.
// The primary template is only declared; the specializations follow.
631template <
typename Source>
632struct VectorToScalarMapper;
// vector_reduce_fmaximum -> MaximumOp
634struct VectorToScalarMapper<
LLVM::vector_reduce_fmaximum> {
635 using Type = LLVM::MaximumOp;
// vector_reduce_fminimum -> MinimumOp
638struct VectorToScalarMapper<
LLVM::vector_reduce_fminimum> {
639 using Type = LLVM::MinimumOp;
// vector_reduce_fmax -> MaxNumOp
642struct VectorToScalarMapper<
LLVM::vector_reduce_fmax> {
643 using Type = LLVM::MaxNumOp;
// vector_reduce_fmin -> MinNumOp
646struct VectorToScalarMapper<
LLVM::vector_reduce_fmin> {
647 using Type = LLVM::MinNumOp;
// Lowers a floating-point min/max style reduction using `LLVMRedIntrinOp`.
// The intrinsic is created over `vectorOperand` with the fast-math flags
// `fmf`. The scalar op selected by VectorToScalarMapper then combines that
// result with `accumulator`. The condition guarding the combine step and the
// return are not visible in this excerpt.
651template <
class LLVMRedIntrinOp>
652static Value createFPReductionComparisonOpLowering(
653 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
654 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
656 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
659 result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
660 rewriter, loc,
result, accumulator);
// Empty tag types that select a getMaskNeutralValue overload: one for the
// masked floating-point maximum reduction and one for the minimum.
667class MaskNeutralFMaximum {};
668class MaskNeutralFMinimum {};
// MaskNeutralFMaximum overload: returns APFloat::getSmallest for the given
// semantics with the second argument set to true.
// NOTE(review): the LLVM docs describe getSmallest as the smallest finite
// number *by magnitude*, with the second argument selecting the sign. That
// is not the most negative value. Confirm whether getLargest(..., true) or
// negative infinity was intended as the neutral for a maximum reduction.
672getMaskNeutralValue(MaskNeutralFMaximum,
673 const llvm::fltSemantics &floatSemantics) {
674 return llvm::APFloat::getSmallest(floatSemantics,
true);
// MaskNeutralFMinimum overload: returns APFloat::getLargest for the given
// semantics with the second argument set to false.
678getMaskNeutralValue(MaskNeutralFMinimum,
679 const llvm::fltSemantics &floatSemantics) {
680 return llvm::APFloat::getLargest(floatSemantics,
false);
// Builds a vector constant holding the mask-neutral value selected by the
// `MaskNeutral` tag.
// The float semantics come from `llvmType` (cast to FloatType), and the
// scalar neutral value from getMaskNeutralValue. The construction of
// `denseValue` is not visible in this excerpt.
684template <
typename MaskNeutral>
685static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
688 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
689 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
691 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
// Lowers a masked floating-point reduction through the regular (unmasked)
// reduction intrinsic `LLVMRedIntrinOp`.
// Lanes where the mask is false are replaced by the mask-neutral value
// selected by `MaskNeutral`; the selected vector is then reduced by
// createFPReductionComparisonOpLowering.
698template <
class LLVMRedIntrinOp,
class MaskNeutral>
700lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
703 Value mask, LLVM::FastmathFlagsAttr fmf) {
// Vector constant of the operand's type filled with the neutral value.
704 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
705 rewriter, loc, llvmType, vectorOperand.
getType());
// Keep operand lanes where the mask is set, neutral lanes elsewhere.
706 const Value selectedVectorByMask = LLVM::SelectOp::create(
707 rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
708 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
709 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
// Lowers a reduction whose intrinsic takes an explicit start value.
// The accumulator is first passed through getOrCreateAccumulator with the
// `ReductionNeutral` tag. The intrinsic is then created with the accumulator
// and the vector operand; the trailing arguments of that call are not
// visible in this excerpt.
712template <
class LLVMRedIntrinOp,
class ReductionNeutral>
714lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc,
716 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
717 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
718 llvmType, accumulator);
719 return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
720 accumulator, vectorOperand,
// Overload that creates the vector-predicated intrinsic `LLVMVPRedIntrinOp`
// from just the accumulator and the vector operand.
// The accumulator is first passed through getOrCreateAccumulator with the
// `ReductionNeutral` tag. Part of the parameter list is not visible here.
727template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
729lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
732 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
733 llvmType, accumulator);
734 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
735 accumulator, vectorOperand);
// Overload that also computes a vector-length value from the operand's type
// via createVectorLengthValue before creating `LLVMVPRedIntrinOp`.
// The accumulator is first passed through getOrCreateAccumulator with the
// `ReductionNeutral` tag. The trailing arguments of the final create call
// are not visible in this excerpt.
738template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
739static Value lowerPredicatedReductionWithStartValue(
740 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
742 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
743 llvmType, accumulator);
745 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
746 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
747 accumulator, vectorOperand,
// Overload that takes an integer intrinsic/neutral pair and a floating-point
// intrinsic/neutral pair, and forwards to the two-parameter overload with
// one pair or the other. The conditions choosing between the two returns are
// not visible in this excerpt.
751template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
752 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
753static Value lowerPredicatedReductionWithStartValue(
754 ConversionPatternRewriter &rewriter,
Location loc,
Type llvmType,
// Integer variant.
757 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
758 IntReductionNeutral>(
759 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
// Floating-point variant.
762 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
764 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
// Conversion pattern for vector::ReductionOp. Only part of the class is
// visible in this excerpt; the comments describe the visible lines only.
768class VectorReductionOpConversion
// The constructor forwards the type converter to the base pattern and records
// `reassociateFPRed` in the `reassociateFPReductions` member.
771 explicit VectorReductionOpConversion(
const LLVMTypeConverter &typeConv,
772 bool reassociateFPRed)
773 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
774 reassociateFPReductions(reassociateFPRed) {}
777 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
778 ConversionPatternRewriter &rewriter)
const override {
// Collect the combining kind, the converted result element type, and the
// converted vector and accumulator operands.
779 auto kind = reductionOp.getKind();
780 Type eltType = reductionOp.getDest().getType();
781 Type llvmType = typeConverter->convertType(eltType);
782 Value operand = adaptor.getVector();
783 Value acc = adaptor.getAcc();
784 Location loc = reductionOp.getLoc();
// Integer dispatch. ADD/MUL/AND/OR/XOR use the arithmetic lowering (the
// scalar op template argument is on lines not visible here); the min/max
// kinds use the comparison lowering with the predicate shown.
790 case vector::CombiningKind::ADD:
792 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
794 rewriter, loc, llvmType, operand, acc);
796 case vector::CombiningKind::MUL:
798 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
800 rewriter, loc, llvmType, operand, acc);
802 case vector::CombiningKind::MINUI:
803 result = createIntegerReductionComparisonOpLowering<
804 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
805 LLVM::ICmpPredicate::ule);
807 case vector::CombiningKind::MINSI:
808 result = createIntegerReductionComparisonOpLowering<
809 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
810 LLVM::ICmpPredicate::sle);
812 case vector::CombiningKind::MAXUI:
813 result = createIntegerReductionComparisonOpLowering<
814 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
815 LLVM::ICmpPredicate::uge);
817 case vector::CombiningKind::MAXSI:
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::sge);
822 case vector::CombiningKind::AND:
824 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
826 rewriter, loc, llvmType, operand, acc);
828 case vector::CombiningKind::OR:
830 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
832 rewriter, loc, llvmType, operand, acc);
834 case vector::CombiningKind::XOR:
836 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
838 rewriter, loc, llvmType, operand, acc);
843 rewriter.replaceOp(reductionOp,
result);
// From here on the element type must be a FloatType; the statement taken
// otherwise is not visible in this excerpt.
848 if (!isa<FloatType>(eltType))
// Build the LLVM fast-math flags attribute from the op's fast-math attribute
// (conversion arguments not visible), then add `reassoc` when
// `reassociateFPReductions` is set.
851 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
852 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
853 reductionOp.getContext(),
855 fmf = LLVM::FastmathFlagsAttr::get(
856 reductionOp.getContext(),
857 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
858 : LLVM::FastmathFlags::none));
// Floating-point dispatch. ADD and MUL use the start-value form with neutral
// zero and FP one respectively; the min/max kinds use the FP comparison
// lowering with the intrinsic named in each branch.
862 if (kind == vector::CombiningKind::ADD) {
863 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
864 ReductionNeutralZero>(
865 rewriter, loc, llvmType, operand, acc, fmf);
866 }
else if (kind == vector::CombiningKind::MUL) {
867 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
868 ReductionNeutralFPOne>(
869 rewriter, loc, llvmType, operand, acc, fmf);
870 }
else if (kind == vector::CombiningKind::MINIMUMF) {
872 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
873 rewriter, loc, llvmType, operand, acc, fmf);
874 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
876 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
877 rewriter, loc, llvmType, operand, acc, fmf);
878 }
else if (kind == vector::CombiningKind::MINNUMF) {
879 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 }
else if (kind == vector::CombiningKind::MAXNUMF) {
882 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
883 rewriter, loc, llvmType, operand, acc, fmf);
888 rewriter.replaceOp(reductionOp,
result);
// When set, `reassoc` is OR'ed into the fast-math flags of FP reductions.
893 const bool reassociateFPReductions;
// Base pattern for converting a vector::MaskOp whose maskable op is of type
// `MaskedOp`. matchAndRewrite is `final`; subclasses implement
// matchAndRewriteMaskableOp instead.
904template <
class MaskedOp>
905class VectorMaskOpConversionBase
908 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
911 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
912 ConversionPatternRewriter &rewriter)
const final {
// dyn_cast_or_null yields a null `maskedOp` when the maskable op is absent or
// is not a `MaskedOp`; the handling of that case is not visible here.
914 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
917 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
// Hook implemented by subclasses to rewrite the mask op and its maskable op.
921 virtual LogicalResult
922 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
923 vector::MaskableOpInterface maskableOp,
924 ConversionPatternRewriter &rewriter)
const = 0;
// Conversion of a vector::MaskOp that wraps a vector::ReductionOp. Only part
// of the class is visible in this excerpt.
927class MaskedReductionOpConversion
928 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
931 using VectorMaskOpConversionBase<
932 vector::ReductionOp>::VectorMaskOpConversionBase;
934 LogicalResult matchAndRewriteMaskableOp(
935 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
936 ConversionPatternRewriter &rewriter)
const override {
// The vector and accumulator operands are read directly from the reduction
// op here, not from an adaptor.
937 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
938 auto kind = reductionOp.getKind();
939 Type eltType = reductionOp.getDest().getType();
940 Type llvmType = typeConverter->convertType(eltType);
941 Value operand = reductionOp.getVector();
942 Value acc = reductionOp.getAcc();
943 Location loc = reductionOp.getLoc();
// LLVM fast-math flags derived from the op's fast-math attribute (conversion
// arguments not visible in this excerpt).
945 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
946 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
947 reductionOp.getContext(),
// ADD and MUL pass both an integer and a floating-point VP intrinsic with
// their neutral tags; the trailing call arguments are not visible.
952 case vector::CombiningKind::ADD:
953 result = lowerPredicatedReductionWithStartValue<
954 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
955 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
958 case vector::CombiningKind::MUL:
959 result = lowerPredicatedReductionWithStartValue<
960 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
961 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
// The remaining VP cases pair one intrinsic with one neutral tag and pass
// the mask of `maskOp`. Min kinds use a "max" neutral and max kinds a "min"
// neutral.
964 case vector::CombiningKind::MINUI:
965 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
966 ReductionNeutralUIntMax>(
967 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
969 case vector::CombiningKind::MINSI:
970 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
971 ReductionNeutralSIntMax>(
972 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
974 case vector::CombiningKind::MAXUI:
975 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
976 ReductionNeutralUIntMin>(
977 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
979 case vector::CombiningKind::MAXSI:
980 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
981 ReductionNeutralSIntMin>(
982 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
984 case vector::CombiningKind::AND:
985 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
986 ReductionNeutralAllOnes>(
987 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
989 case vector::CombiningKind::OR:
990 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
991 ReductionNeutralZero>(
992 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
994 case vector::CombiningKind::XOR:
995 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
996 ReductionNeutralZero>(
997 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
999 case vector::CombiningKind::MINNUMF:
1000 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1001 ReductionNeutralFPMax>(
1002 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1004 case vector::CombiningKind::MAXNUMF:
1005 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1006 ReductionNeutralFPMin>(
1007 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
// MAXIMUMF / MINIMUMF go through the regular (unmasked) reduction intrinsic,
// with masked-off lanes replaced by a mask-neutral value.
1009 case CombiningKind::MAXIMUMF:
1010 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1011 MaskNeutralFMaximum>(
1012 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1014 case CombiningKind::MINIMUMF:
1015 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1016 MaskNeutralFMinimum>(
1017 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
// The whole mask op (not just the inner reduction) is replaced by the result.
1022 rewriter.replaceOp(maskOp,
result);
// Conversion pattern for vector::ShuffleOp. Only part of the class is visible
// in this excerpt.
1027class VectorShuffleOpConversion
1030 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1033 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1034 ConversionPatternRewriter &rewriter)
const override {
1035 auto loc = shuffleOp->getLoc();
1036 auto v1Type = shuffleOp.getV1VectorType();
1037 auto v2Type = shuffleOp.getV2VectorType();
1038 auto vectorType = shuffleOp.getResultVectorType();
1039 Type llvmType = typeConverter->convertType(vectorType);
1040 ArrayRef<int64_t> mask = shuffleOp.getMask();
// Sanity check: either both inputs are 0-D and the result is 1-D, or both
// inputs have the same rank as the result.
1047 int64_t rank = vectorType.getRank();
1049 bool wellFormed0DCase =
1050 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1051 bool wellFormedNDCase =
1052 v1Type.getRank() == rank && v2Type.getRank() == rank;
1053 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
// Rank <= 1 with identical input types: emit a single LLVM::ShuffleVectorOp
// with the mask converted to int32_t.
1058 if (rank <= 1 && v1Type == v2Type) {
1059 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1060 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1061 llvm::to_vector_of<int32_t>(mask));
1062 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
// Otherwise build the result one mask entry at a time, starting from a
// poison value of the converted result type.
1067 int64_t v1Dim = v1Type.getDimSize(0);
1069 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1070 eltType = arrayType.getElementType();
1072 eltType = cast<VectorType>(llvmType).getElementType();
1073 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
// Mask positions at or beyond v1's leading dimension select from V2. The
// adjustment of `extPos` in that case is on a line not visible here.
1075 for (int64_t extPos : mask) {
1076 Value value = adaptor.getV1();
1077 if (extPos >= v1Dim) {
1079 value = adaptor.getV2();
1081 Value extract =
extractOne(rewriter, *getTypeConverter(), loc, value,
1082 eltType, rank, extPos);
1083 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1084 llvmType, rank, insPos++);
1086 rewriter.replaceOp(shuffleOp, insert);
// Conversion pattern for vector::ExtractOp. Only part of the class is visible
// in this excerpt.
1091class VectorExtractOpConversion
1094 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1097 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1098 ConversionPatternRewriter &rewriter)
const override {
1099 auto loc = extractOp->getLoc();
1100 auto resultType = extractOp.getResult().getType();
1101 auto llvmResultType = typeConverter->convertType(resultType);
// The statement taken when the result type cannot be converted is not
// visible in this excerpt.
1103 if (!llvmResultType)
1107 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
// Sources of rank >= 2 need an aggregate extraction step first.
1121 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
// A scalar is extracted when there is one position per source dimension.
1125 bool extractsScalar =
static_cast<int64_t
>(positionVec.size()) ==
1126 extractOp.getSourceVectorType().getRank();
// For a rank-0 source, a zero position of the converted index type is
// appended.
1130 if (extractOp.getSourceVectorType().getRank() == 0) {
1131 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1132 positionVec.push_back(rewriter.getZeroAttr(idxType));
1135 Value extracted = adaptor.getSource();
1136 if (extractsAggregate) {
1137 ArrayRef<OpFoldResult> position(positionVec);
// When a scalar is the final target, the last position is held back for the
// element extraction below.
1138 if (extractsScalar) {
1142 position = position.drop_back();
// The aggregate positions are checked to all be attributes; the handling of
// a non-attribute position is not visible in this excerpt.
1145 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1148 extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1152 if (extractsScalar) {
1153 extracted = LLVM::ExtractElementOp::create(
1154 rewriter, loc, extracted,
1158 rewriter.replaceOp(extractOp, extracted);
1179 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
// Rewrites a vector::FMAOp into LLVM::FMulAddOp using the converted lhs, rhs
// and acc operands. The vector type's rank is tested against 1 first; the
// statement taken for larger ranks is not visible in this excerpt.
1182 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1183 ConversionPatternRewriter &rewriter)
const override {
1184 VectorType vType = fmaOp.getVectorType();
1185 if (vType.getRank() > 1)
1188 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1189 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
// Conversion pattern for vector::InsertOp. Only part of the class is visible
// in this excerpt.
1194class VectorInsertOpConversion
1197 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1200 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1201 ConversionPatternRewriter &rewriter)
const override {
1202 auto loc = insertOp->getLoc();
1203 auto destVectorType = insertOp.getDestVectorType();
1204 auto llvmResultType = typeConverter->convertType(destVectorType);
// The statement taken when the destination type cannot be converted is not
// visible in this excerpt.
1206 if (!llvmResultType)
1210 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
// The converted destination is a nested aggregate when it is an LLVM array.
1232 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
// The insertion targets the innermost dimension when there is one position
// per destination dimension.
1234 bool insertIntoInnermostDim =
1235 static_cast<int64_t
>(positionVec.size()) == destVectorType.getRank();
// Split the position: the leading part addresses a 1-D vector inside the
// aggregate; the last entry (if any) addresses the scalar inside it.
1237 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1238 positionVec.begin(),
1239 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1240 OpFoldResult positionOfScalarWithin1DVector;
// For a rank-0 destination, the scalar position is a zero of the converted
// index type.
1241 if (destVectorType.getRank() == 0) {
1244 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1245 positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1246 }
else if (insertIntoInnermostDim) {
1247 positionOfScalarWithin1DVector = positionVec.back();
1253 Value sourceAggregate = adaptor.getValueToStore();
1254 if (insertIntoInnermostDim) {
// For a nested aggregate, the 1-D vector is first extracted from the
// destination; the aggregate positions are checked to all be attributes.
1257 if (isNestedAggregate) {
1260 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1261 llvm::IsaPred<Attribute>)) {
1265 sourceAggregate = LLVM::ExtractValueOp::create(
1266 rewriter, loc, adaptor.getDest(),
1271 sourceAggregate = adaptor.getDest();
// Insert the value to store into the 1-D vector.
1274 sourceAggregate = LLVM::InsertElementOp::create(
1275 rewriter, loc, sourceAggregate.
getType(), sourceAggregate,
1276 adaptor.getValueToStore(),
// For a nested aggregate, the updated value is written back into the
// destination with InsertValueOp.
1280 Value
result = sourceAggregate;
1281 if (isNestedAggregate) {
1282 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1283 llvm::IsaPred<Attribute>)) {
1287 result = LLVM::InsertValueOp::create(
1288 rewriter, loc, adaptor.getDest(), sourceAggregate,
1292 rewriter.replaceOp(insertOp,
result);
// Lowers vector::ScalableInsertOp to LLVM::vector_insert, passing the
// converted destination, the value to store and the position.
1298struct VectorScalableInsertOpLowering
1300 using ConvertOpToLLVMPattern<
1301 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1304 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1305 ConversionPatternRewriter &rewriter)
const override {
1306 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1307 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
// Lowers vector::ScalableExtractOp to LLVM::vector_extract, passing the
// converted result vector type, the converted source and the position.
1313struct VectorScalableExtractOpLowering
1315 using ConvertOpToLLVMPattern<
1316 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1319 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1320 ConversionPatternRewriter &rewriter)
const override {
1321 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1322 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1323 adaptor.getSource(), adaptor.getPos());
1356 setHasBoundedRewriteRecursion();
// Unrolls an FMAOp along its leading dimension: each slice of lhs/rhs/acc is
// extracted, combined by a new FMAOp, and inserted into the accumulated
// result `desc`. The rank is tested with `< 2` first; the statement taken
// for such ranks is not visible in this excerpt.
1359 LogicalResult matchAndRewrite(FMAOp op,
1360 PatternRewriter &rewriter)
const override {
1361 auto vType = op.getVectorType();
1362 if (vType.getRank() < 2)
1365 auto loc = op.getLoc();
1366 auto elemType = vType.getElementType();
// `desc` starts as a broadcast of the constant `zero` to the full vector type
// (the constant's value argument is on a line not visible here).
1367 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1369 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
// One iteration per element of the leading dimension.
1370 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1371 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1372 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1373 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1374 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1375 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1384static std::optional<SmallVector<int64_t, 4>>
1385computeContiguousStrides(MemRefType memRefType) {
1388 if (
failed(memRefType.getStridesAndOffset(strides, offset)))
1389 return std::nullopt;
1390 if (!strides.empty() && strides.back() != 1)
1391 return std::nullopt;
1393 if (memRefType.getLayout().isIdentity())
1400 auto sizes = memRefType.getShape();
1402 if (ShapedType::isDynamic(sizes[
index + 1]) ||
1403 ShapedType::isDynamic(strides[
index]) ||
1404 ShapedType::isDynamic(strides[
index + 1]))
1405 return std::nullopt;
1407 return std::nullopt;
1412class VectorTypeCastOpConversion
1415 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1418 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1419 ConversionPatternRewriter &rewriter)
const override {
1420 auto loc = castOp->getLoc();
1421 MemRefType sourceMemRefType =
1422 cast<MemRefType>(castOp.getOperand().getType());
1423 MemRefType targetMemRefType = castOp.getType();
1426 if (!sourceMemRefType.hasStaticShape() ||
1427 !targetMemRefType.hasStaticShape())
1430 auto llvmSourceDescriptorTy =
1431 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1432 if (!llvmSourceDescriptorTy)
1434 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1436 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1437 typeConverter->convertType(targetMemRefType));
1438 if (!llvmTargetDescriptorTy)
1442 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1445 auto targetStrides = computeContiguousStrides(targetMemRefType);
1449 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1452 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1455 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1457 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1458 desc.setAllocatedPtr(rewriter, loc, allocated);
1461 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1462 desc.setAlignedPtr(rewriter, loc, ptr);
1464 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1465 auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1466 desc.setOffset(rewriter, loc, zero);
1469 for (
const auto &indexedSize :
1470 llvm::enumerate(targetMemRefType.getShape())) {
1471 int64_t index = indexedSize.index();
1473 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1474 auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1475 desc.setSize(rewriter, loc, index, size);
1476 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1477 (*targetStrides)[index]);
1479 LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1480 desc.setStride(rewriter, loc, index, stride);
1483 rewriter.replaceOp(castOp, {desc});
1490class VectorCreateMaskOpConversion
1491 :
public OpConversionPattern<vector::CreateMaskOp> {
1493 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1494 bool enableIndexOpt)
1495 : OpConversionPattern<vector::CreateMaskOp>(context),
1496 force32BitVectorIndices(enableIndexOpt) {}
1499 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1500 ConversionPatternRewriter &rewriter)
const override {
1501 auto dstType = op.getType();
1502 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1504 IntegerType idxType =
1505 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1506 auto loc = op->getLoc();
1507 Value
indices = LLVM::StepVectorOp::create(
1512 adaptor.getOperands()[0]);
1513 Value bounds = BroadcastOp::create(rewriter, loc,
indices.getType(), bound);
1514 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1516 rewriter.replaceOp(op, comp);
1521 const bool force32BitVectorIndices;
1525 SymbolTableCollection *symbolTables =
nullptr;
1528 explicit VectorPrintOpConversion(
1529 const LLVMTypeConverter &typeConverter,
1530 SymbolTableCollection *symbolTables =
nullptr)
1531 : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1532 symbolTables(symbolTables) {}
1548 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1549 ConversionPatternRewriter &rewriter)
const override {
1550 auto parent =
printOp->getParentOfType<ModuleOp>();
1556 if (
auto value = adaptor.getSource()) {
1558 if (isa<VectorType>(printType)) {
1562 if (
failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1566 auto punct =
printOp.getPunctuation();
1567 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1570 *stringLiteral, *getTypeConverter(),
1572 if (createResult.failed())
1575 }
else if (punct != PrintPunctuation::NoPunctuation) {
1576 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1578 case PrintPunctuation::Close:
1581 case PrintPunctuation::Open:
1584 case PrintPunctuation::Comma:
1587 case PrintPunctuation::NewLine:
1591 llvm_unreachable(
"unexpected punctuation");
1596 emitCall(rewriter,
printOp->getLoc(), op.value());
1604 enum class PrintConversion {
1613 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1614 ModuleOp parent, Location loc, Type printType,
1615 Value value)
const {
1616 if (typeConverter->convertType(printType) ==
nullptr)
1620 PrintConversion conversion = PrintConversion::None;
1621 FailureOr<Operation *> printer;
1627 conversion = PrintConversion::Bitcast16;
1630 conversion = PrintConversion::Bitcast16;
1634 }
else if (
auto intTy = dyn_cast<IntegerType>(printType)) {
1638 unsigned width = intTy.getWidth();
1639 if (intTy.isUnsigned()) {
1642 conversion = PrintConversion::ZeroExt64;
1649 assert(intTy.isSignless() || intTy.isSigned());
1654 conversion = PrintConversion::ZeroExt64;
1655 else if (width < 64)
1656 conversion = PrintConversion::SignExt64;
1663 }
else if (
auto floatTy = dyn_cast<FloatType>(printType)) {
1666 llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1667 Value semValue = LLVM::ConstantOp::create(
1668 rewriter, loc, rewriter.getI32Type(),
1669 rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1671 LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1674 emitCall(rewriter, loc, printer.value(),
1683 switch (conversion) {
1684 case PrintConversion::ZeroExt64:
1685 value = arith::ExtUIOp::create(
1686 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1688 case PrintConversion::SignExt64:
1689 value = arith::ExtSIOp::create(
1690 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1692 case PrintConversion::Bitcast16:
1693 value = LLVM::BitcastOp::create(
1694 rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
1696 case PrintConversion::None:
1699 emitCall(rewriter, loc, printer.value(), value);
1704 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1706 LLVM::CallOp::create(rewriter, loc,
TypeRange(), SymbolRefAttr::get(ref),
1714struct VectorBroadcastScalarToLowRankLowering
1716 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1719 matchAndRewrite(vector::BroadcastOp
broadcast, OpAdaptor adaptor,
1720 ConversionPatternRewriter &rewriter)
const override {
1721 if (isa<VectorType>(
broadcast.getSourceType()))
1722 return rewriter.notifyMatchFailure(
1723 broadcast,
"broadcast from vector type not handled");
1726 if (resultType.getRank() > 1)
1727 return rewriter.notifyMatchFailure(
broadcast,
1728 "broadcast to 2+-d handled elsewhere");
1734 auto zero = LLVM::ConstantOp::create(
1736 typeConverter->convertType(rewriter.getIntegerType(32)),
1737 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1740 if (resultType.getRank() == 0) {
1741 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1742 broadcast, vectorType, poison, adaptor.getSource(), zero);
1747 LLVM::InsertElementOp::create(rewriter,
broadcast.
getLoc(), vectorType,
1748 poison, adaptor.getSource(), zero);
1752 SmallVector<int32_t> zeroValues(width, 0);
1755 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1766struct VectorBroadcastScalarToNdLowering
1768 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1771 matchAndRewrite(BroadcastOp
broadcast, OpAdaptor adaptor,
1772 ConversionPatternRewriter &rewriter)
const override {
1773 if (isa<VectorType>(
broadcast.getSourceType()))
1774 return rewriter.notifyMatchFailure(
1775 broadcast,
"broadcast from vector type not handled");
1778 if (resultType.getRank() <= 1)
1779 return rewriter.notifyMatchFailure(
1780 broadcast,
"broadcast to 1-d or 0-d handled elsewhere");
1784 auto vectorTypeInfo =
1786 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1787 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1788 if (!llvmNDVectorTy || !llvm1DVectorTy)
1792 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1796 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1797 auto zero = LLVM::ConstantOp::create(
1798 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1799 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1800 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1801 vdesc, adaptor.getSource(), zero);
1804 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1805 SmallVector<int32_t> zeroValues(width, 0);
1806 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1810 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1811 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1820struct VectorInterleaveOpLowering
1825 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1826 ConversionPatternRewriter &rewriter)
const override {
1827 VectorType resultType = interleaveOp.getResultVectorType();
1829 if (resultType.getRank() != 1)
1830 return rewriter.notifyMatchFailure(interleaveOp,
1831 "InterleaveOp not rank 1");
1833 if (resultType.isScalable()) {
1834 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1835 interleaveOp, typeConverter->convertType(resultType),
1836 adaptor.getLhs(), adaptor.getRhs());
1843 int64_t resultVectorSize = resultType.getNumElements();
1844 SmallVector<int32_t> interleaveShuffleMask;
1845 interleaveShuffleMask.reserve(resultVectorSize);
1846 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1847 interleaveShuffleMask.push_back(i);
1848 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1850 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1851 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1852 interleaveShuffleMask);
1859struct VectorDeinterleaveOpLowering
1864 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1865 ConversionPatternRewriter &rewriter)
const override {
1866 VectorType resultType = deinterleaveOp.getResultVectorType();
1867 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1868 auto loc = deinterleaveOp.getLoc();
1872 if (resultType.getRank() != 1)
1873 return rewriter.notifyMatchFailure(deinterleaveOp,
1874 "DeinterleaveOp not rank 1");
1876 if (resultType.isScalable()) {
1877 const auto *llvmTypeConverter = this->getTypeConverter();
1878 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1879 auto packedOpResults =
1880 llvmTypeConverter->packOperationResults(deinterleaveResults);
1881 auto intrinsic = LLVM::vector_deinterleave2::create(
1882 rewriter, loc, packedOpResults, adaptor.getSource());
1884 auto evenResult = LLVM::ExtractValueOp::create(
1885 rewriter, loc, intrinsic->getResult(0), 0);
1886 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1887 intrinsic->getResult(0), 1);
1889 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenResult, oddResult});
1896 int64_t resultVectorSize = resultType.getNumElements();
1897 SmallVector<int32_t> evenShuffleMask;
1898 SmallVector<int32_t> oddShuffleMask;
1900 evenShuffleMask.reserve(resultVectorSize);
1901 oddShuffleMask.reserve(resultVectorSize);
1903 for (
int i = 0; i < sourceType.getNumElements(); ++i) {
1905 evenShuffleMask.push_back(i);
1907 oddShuffleMask.push_back(i);
1910 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1911 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1912 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1913 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1914 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1916 rewriter.replaceOp(deinterleaveOp,
ValueRange{evenShuffle, oddShuffle});
1922struct VectorFromElementsLowering
1927 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1928 ConversionPatternRewriter &rewriter)
const override {
1929 Location loc = fromElementsOp.getLoc();
1930 VectorType vectorType = fromElementsOp.getType();
1934 if (vectorType.getRank() > 1)
1935 return rewriter.notifyMatchFailure(fromElementsOp,
1936 "rank > 1 vectors are not supported");
1937 Type llvmType = typeConverter->convertType(vectorType);
1938 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1939 Value
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1940 for (
auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1942 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1943 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType,
result,
1946 rewriter.replaceOp(fromElementsOp,
result);
1952struct VectorToElementsLowering
1957 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1958 ConversionPatternRewriter &rewriter)
const override {
1959 Location loc = toElementsOp.getLoc();
1960 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1961 Value source = adaptor.getSource();
1963 SmallVector<Value> results(toElementsOp->getNumResults());
1964 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1966 if (element.use_empty())
1969 auto constIdx = LLVM::ConstantOp::create(
1970 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1971 auto llvmType = typeConverter->convertType(element.getType());
1973 Value
result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1978 rewriter.replaceOp(toElementsOp, results);
1984struct VectorScalableStepOpLowering
1989 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1990 ConversionPatternRewriter &rewriter)
const override {
1991 auto resultType = cast<VectorType>(stepOp.getType());
1992 if (!resultType.isScalable()) {
1995 Type llvmType = typeConverter->convertType(stepOp.getType());
1996 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
2011class ContractionOpToMatmulOpLowering
2014 using MaskableOpRewritePattern::MaskableOpRewritePattern;
2016 ContractionOpToMatmulOpLowering(MLIRContext *context,
2017 PatternBenefit benefit = 100)
2018 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2021 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
2022 PatternRewriter &rewriter)
const override;
2042FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2043 vector::ContractionOp op, MaskingOpInterface maskOp,
2049 auto iteratorTypes = op.getIteratorTypes().getValue();
2055 Type opResType = op.getType();
2056 VectorType vecType = dyn_cast<VectorType>(opResType);
2057 if (vecType && vecType.isScalable()) {
2062 Type elementType = op.getLhsType().getElementType();
2066 Type dstElementType = vecType ? vecType.getElementType() : opResType;
2067 if (elementType != dstElementType)
2072 MLIRContext *ctx = op.getContext();
2073 Location loc = op.getLoc();
2077 Value
lhs = op.getLhs();
2078 auto lhsMap = op.getIndexingMapsArray()[0];
2080 lhs = vector::TransposeOp::create(rew, loc,
lhs, ArrayRef<int64_t>{1, 0});
2085 Value
rhs = op.getRhs();
2086 auto rhsMap = op.getIndexingMapsArray()[1];
2088 rhs = vector::TransposeOp::create(rew, loc,
rhs, ArrayRef<int64_t>{1, 0});
2093 VectorType lhsType = cast<VectorType>(
lhs.getType());
2094 VectorType rhsType = cast<VectorType>(
rhs.getType());
2095 int64_t lhsRows = lhsType.getDimSize(0);
2096 int64_t lhsColumns = lhsType.getDimSize(1);
2097 int64_t rhsColumns = rhsType.getDimSize(1);
2099 Type flattenedLHSType =
2100 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2101 lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType,
lhs);
2103 Type flattenedRHSType =
2104 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2105 rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType,
rhs);
2107 Value
mul = LLVM::MatrixMultiplyOp::create(
2109 VectorType::get(lhsRows * rhsColumns,
2110 cast<VectorType>(
lhs.getType()).getElementType()),
2111 lhs,
rhs, lhsRows, lhsColumns, rhsColumns);
2113 mul = vector::ShapeCastOp::create(
2115 VectorType::get({lhsRows, rhsColumns},
2120 auto accMap = op.getIndexingMapsArray()[2];
2122 mul = vector::TransposeOp::create(rew, loc,
mul, ArrayRef<int64_t>{1, 0});
2124 llvm_unreachable(
"invalid contraction semantics");
2126 Value res = isa<IntegerType>(elementType)
2127 ?
static_cast<Value
>(
2128 arith::AddIOp::create(rew, loc, op.getAcc(),
mul))
2129 : static_cast<Value>(
2130 arith::AddFOp::create(rew, loc, op.getAcc(),
mul));
2148class TransposeOpToMatrixTransposeOpLowering
2149 :
public OpRewritePattern<vector::TransposeOp> {
2153 LogicalResult matchAndRewrite(vector::TransposeOp op,
2154 PatternRewriter &rewriter)
const override {
2155 auto loc = op.getLoc();
2157 Value input = op.getVector();
2158 VectorType inputType = op.getSourceVectorType();
2159 VectorType resType = op.getResultVectorType();
2161 if (inputType.isScalable())
2163 op,
"This lowering does not support scalable vectors");
2166 ArrayRef<int64_t> transp = op.getPermutation();
2168 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2172 Type flattenedType =
2173 VectorType::get(resType.getNumElements(), resType.getElementType());
2175 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2178 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2179 matrix, rows, columns);
2189 patterns.
add<VectorFMAOpNDRewritePattern>(patterns.
getContext());
2194 patterns.
add<ContractionOpToMatmulOpLowering>(patterns.
getContext(), benefit);
2199 patterns.
add<TransposeOpToMatrixTransposeOpLowering>(patterns.
getContext(),
2206 bool reassociateFPReductions,
bool force32BitVectorIndices,
2207 bool useVectorAlignment) {
2210 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2211 patterns.
add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2212 patterns.
add<VectorLoadStoreConversion<vector::LoadOp>,
2213 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2214 VectorLoadStoreConversion<vector::StoreOp>,
2215 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2216 VectorGatherOpConversion, VectorScatterOpConversion>(
2217 converter, useVectorAlignment);
2218 patterns.
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2219 VectorExtractOpConversion, VectorFMAOp1DConversion,
2220 VectorInsertOpConversion, VectorPrintOpConversion,
2221 VectorTypeCastOpConversion, VectorScaleOpConversion,
2222 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2223 VectorBroadcastScalarToLowRankLowering,
2224 VectorBroadcastScalarToNdLowering,
2225 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2226 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2227 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2228 VectorToElementsLowering, VectorScalableStepOpLowering>(
2235 void loadDependentDialects(
MLIRContext *context)
const final {
2236 context->loadDialect<LLVM::LLVMDialect>();
2241 void populateConvertToLLVMConversionPatterns(
2242 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2243 RewritePatternSet &patterns)
const final {
2252 dialect->addInterfaces<VectorToLLVMDialectInterface>();
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type` or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e. inside a vector.mask op region).