30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
39 assert((tp.getRank() > 1) &&
"unlowerable vector type");
41 tp.getScalableDims().take_back());
49 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
52 auto constant = rewriter.
create<LLVM::ConstantOp>(
55 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58 return rewriter.
create<LLVM::InsertValueOp>(loc, val1, val2, pos);
64 Value val,
Type llvmType, int64_t rank, int64_t pos) {
67 auto constant = rewriter.
create<LLVM::ConstantOp>(
70 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
73 return rewriter.
create<LLVM::ExtractValueOp>(loc, val, pos);
78 MemRefType memrefType,
unsigned &align) {
79 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
85 llvm::LLVMContext llvmContext;
104 MemRefType memRefType,
Value llvmMemref,
Value base,
105 Value index, uint64_t vLen) {
107 "unsupported memref type");
110 return rewriter.
create<LLVM::GEPOp>(
111 loc, ptrsType, typeConverter.
convertType(memRefType.getElementType()),
119 if (
auto attr = foldResult.dyn_cast<
Attribute>()) {
120 auto intAttr = cast<IntegerAttr>(attr);
121 return builder.
create<LLVM::ConstantOp>(loc, intAttr).getResult();
124 return foldResult.get<
Value>();
130 using VectorScaleOpConversion =
134 class VectorBitCastOpConversion
140 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
143 VectorType resultTy = bitCastOp.getResultVectorType();
144 if (resultTy.getRank() > 1)
146 Type newResultTy = typeConverter->convertType(resultTy);
148 adaptor.getOperands()[0]);
155 class VectorMatmulOpConversion
161 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
164 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
165 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
166 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
173 class VectorFlatTransposeOpConversion
179 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
182 transOp, typeConverter->convertType(transOp.getRes().getType()),
183 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
191 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
192 vector::LoadOpAdaptor adaptor,
193 VectorType vectorTy,
Value ptr,
unsigned align,
197 loadOp.getNontemporal());
200 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
201 vector::MaskedLoadOpAdaptor adaptor,
202 VectorType vectorTy,
Value ptr,
unsigned align,
205 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
208 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
209 vector::StoreOpAdaptor adaptor,
210 VectorType vectorTy,
Value ptr,
unsigned align,
214 storeOp.getNontemporal());
217 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
218 vector::MaskedStoreOpAdaptor adaptor,
219 VectorType vectorTy,
Value ptr,
unsigned align,
222 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
227 template <
class LoadOrStoreOp>
233 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
234 typename LoadOrStoreOp::Adaptor adaptor,
237 VectorType vectorTy = loadOrStoreOp.getVectorType();
238 if (vectorTy.getRank() > 1)
241 auto loc = loadOrStoreOp->getLoc();
242 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
250 auto vtype = cast<VectorType>(
251 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
252 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
253 adaptor.getIndices(), rewriter);
254 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
261 class VectorGatherOpConversion
267 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
269 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
270 assert(memRefType &&
"The base should be bufferized");
275 auto loc = gather->getLoc();
282 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
283 adaptor.getIndices(), rewriter);
284 Value base = adaptor.getBase();
286 auto llvmNDVectorTy = adaptor.getIndexVec().
getType();
288 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
289 auto vType = gather.getVectorType();
292 memRefType, base, ptr, adaptor.getIndexVec(),
293 vType.getDimSize(0));
296 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
302 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
303 &typeConverter](
Type llvm1DVectorTy,
307 rewriter, loc, typeConverter, memRefType, base, ptr,
311 return rewriter.create<LLVM::masked_gather>(
312 loc, llvm1DVectorTy, ptrs, vectorOperands[1],
313 vectorOperands[2], rewriter.getI32IntegerAttr(align));
316 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
318 gather, vectorOperands, *getTypeConverter(), callback, rewriter);
323 class VectorScatterOpConversion
329 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
331 auto loc = scatter->getLoc();
332 MemRefType memRefType = scatter.getMemRefType();
343 VectorType vType = scatter.getVectorType();
344 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
345 adaptor.getIndices(), rewriter);
347 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
348 ptr, adaptor.getIndexVec(), vType.getDimSize(0));
352 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
359 class VectorExpandLoadOpConversion
365 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
367 auto loc = expand->getLoc();
368 MemRefType memRefType = expand.getMemRefType();
371 auto vtype = typeConverter->
convertType(expand.getVectorType());
372 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
373 adaptor.getIndices(), rewriter);
376 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
382 class VectorCompressStoreOpConversion
388 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
390 auto loc = compress->getLoc();
391 MemRefType memRefType = compress.getMemRefType();
394 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
395 adaptor.getIndices(), rewriter);
398 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
404 class ReductionNeutralZero {};
405 class ReductionNeutralIntOne {};
406 class ReductionNeutralFPOne {};
407 class ReductionNeutralAllOnes {};
408 class ReductionNeutralSIntMin {};
409 class ReductionNeutralUIntMin {};
410 class ReductionNeutralSIntMax {};
411 class ReductionNeutralUIntMax {};
412 class ReductionNeutralFPMin {};
413 class ReductionNeutralFPMax {};
416 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
419 return rewriter.
create<LLVM::ConstantOp>(loc, llvmType,
424 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
427 return rewriter.
create<LLVM::ConstantOp>(
432 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
435 return rewriter.
create<LLVM::ConstantOp>(
440 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
443 return rewriter.
create<LLVM::ConstantOp>(
450 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
453 return rewriter.
create<LLVM::ConstantOp>(
460 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
463 return rewriter.
create<LLVM::ConstantOp>(
470 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
473 return rewriter.
create<LLVM::ConstantOp>(
480 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
483 return rewriter.
create<LLVM::ConstantOp>(
490 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
493 auto floatType = cast<FloatType>(llvmType);
494 return rewriter.
create<LLVM::ConstantOp>(
497 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
502 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
505 auto floatType = cast<FloatType>(llvmType);
506 return rewriter.
create<LLVM::ConstantOp>(
509 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
515 template <
class ReductionNeutral>
522 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
531 VectorType vType = cast<VectorType>(llvmType);
532 auto vShape = vType.getShape();
533 assert(vShape.size() == 1 &&
"Unexpected multi-dim vector type");
535 return rewriter.
create<LLVM::ConstantOp>(
544 template <
class LLVMRedIntrinOp,
class ScalarOp>
545 static Value createIntegerReductionArithmeticOpLowering(
549 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
552 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
560 template <
class LLVMRedIntrinOp>
561 static Value createIntegerReductionComparisonOpLowering(
563 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
564 Value result = rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
567 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
568 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
574 template <
typename Source>
575 struct VectorToScalarMapper;
577 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
578 using Type = LLVM::MaximumOp;
581 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
582 using Type = LLVM::MinimumOp;
585 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
586 using Type = LLVM::MaxNumOp;
589 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
590 using Type = LLVM::MinNumOp;
594 template <
class LLVMRedIntrinOp>
595 static Value createFPReductionComparisonOpLowering(
597 Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
599 rewriter.
create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
603 rewriter.
create<
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
604 loc, result, accumulator);
611 class MaskNeutralFMaximum {};
612 class MaskNeutralFMinimum {};
616 getMaskNeutralValue(MaskNeutralFMaximum,
617 const llvm::fltSemantics &floatSemantics) {
618 return llvm::APFloat::getSmallest(floatSemantics,
true);
622 getMaskNeutralValue(MaskNeutralFMinimum,
623 const llvm::fltSemantics &floatSemantics) {
624 return llvm::APFloat::getLargest(floatSemantics,
false);
628 template <
typename MaskNeutral>
632 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
633 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
636 return rewriter.
create<LLVM::ConstantOp>(loc, vectorType, denseValue);
643 template <
class LLVMRedIntrinOp,
class MaskNeutral>
648 Value mask, LLVM::FastmathFlagsAttr fmf) {
649 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
650 rewriter, loc, llvmType, vectorOperand.
getType());
651 const Value selectedVectorByMask = rewriter.
create<LLVM::SelectOp>(
652 loc, mask, vectorOperand, vectorMaskNeutral);
653 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
654 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
657 template <
class LLVMRedIntrinOp,
class ReductionNeutral>
661 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
662 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
663 llvmType, accumulator);
664 return rewriter.
create<LLVMRedIntrinOp>(loc, llvmType,
672 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
677 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
678 llvmType, accumulator);
679 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
684 template <
class LLVMVPRedIntrinOp,
class ReductionNeutral>
685 static Value lowerPredicatedReductionWithStartValue(
688 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
689 llvmType, accumulator);
691 createVectorLengthValue(rewriter, loc, vectorOperand.
getType());
692 return rewriter.
create<LLVMVPRedIntrinOp>(loc, llvmType,
694 vectorOperand, mask, vectorLength);
697 template <
class LLVMIntVPRedIntrinOp,
class IntReductionNeutral,
698 class LLVMFPVPRedIntrinOp,
class FPReductionNeutral>
699 static Value lowerPredicatedReductionWithStartValue(
703 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
704 IntReductionNeutral>(
705 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
708 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
710 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
714 class VectorReductionOpConversion
718 bool reassociateFPRed)
720 reassociateFPReductions(reassociateFPRed) {}
723 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
725 auto kind = reductionOp.getKind();
726 Type eltType = reductionOp.getDest().getType();
728 Value operand = adaptor.getVector();
729 Value acc = adaptor.getAcc();
730 Location loc = reductionOp.getLoc();
736 case vector::CombiningKind::ADD:
738 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
740 rewriter, loc, llvmType, operand, acc);
742 case vector::CombiningKind::MUL:
744 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
746 rewriter, loc, llvmType, operand, acc);
749 result = createIntegerReductionComparisonOpLowering<
750 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
751 LLVM::ICmpPredicate::ule);
753 case vector::CombiningKind::MINSI:
754 result = createIntegerReductionComparisonOpLowering<
755 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
756 LLVM::ICmpPredicate::sle);
758 case vector::CombiningKind::MAXUI:
759 result = createIntegerReductionComparisonOpLowering<
760 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
761 LLVM::ICmpPredicate::uge);
763 case vector::CombiningKind::MAXSI:
764 result = createIntegerReductionComparisonOpLowering<
765 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
766 LLVM::ICmpPredicate::sge);
768 case vector::CombiningKind::AND:
770 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
772 rewriter, loc, llvmType, operand, acc);
774 case vector::CombiningKind::OR:
776 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
778 rewriter, loc, llvmType, operand, acc);
780 case vector::CombiningKind::XOR:
782 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
784 rewriter, loc, llvmType, operand, acc);
794 if (!isa<FloatType>(eltType))
797 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
799 reductionOp.getContext(),
802 reductionOp.getContext(),
803 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
804 : LLVM::FastmathFlags::none));
808 if (kind == vector::CombiningKind::ADD) {
809 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
810 ReductionNeutralZero>(
811 rewriter, loc, llvmType, operand, acc, fmf);
812 }
else if (kind == vector::CombiningKind::MUL) {
813 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
814 ReductionNeutralFPOne>(
815 rewriter, loc, llvmType, operand, acc, fmf);
816 }
else if (kind == vector::CombiningKind::MINIMUMF) {
818 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
819 rewriter, loc, llvmType, operand, acc, fmf);
820 }
else if (kind == vector::CombiningKind::MAXIMUMF) {
822 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
823 rewriter, loc, llvmType, operand, acc, fmf);
824 }
else if (kind == vector::CombiningKind::MINNUMF) {
825 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
826 rewriter, loc, llvmType, operand, acc, fmf);
827 }
else if (kind == vector::CombiningKind::MAXNUMF) {
828 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
829 rewriter, loc, llvmType, operand, acc, fmf);
838 const bool reassociateFPReductions;
849 template <
class MaskedOp>
850 class VectorMaskOpConversionBase
856 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
859 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
862 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
867 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
868 vector::MaskableOpInterface maskableOp,
872 class MaskedReductionOpConversion
873 :
public VectorMaskOpConversionBase<vector::ReductionOp> {
876 using VectorMaskOpConversionBase<
877 vector::ReductionOp>::VectorMaskOpConversionBase;
880 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
882 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
883 auto kind = reductionOp.getKind();
884 Type eltType = reductionOp.getDest().getType();
886 Value operand = reductionOp.getVector();
887 Value acc = reductionOp.getAcc();
888 Location loc = reductionOp.getLoc();
890 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
892 reductionOp.getContext(),
897 case vector::CombiningKind::ADD:
898 result = lowerPredicatedReductionWithStartValue<
899 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
900 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
903 case vector::CombiningKind::MUL:
904 result = lowerPredicatedReductionWithStartValue<
905 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
906 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
910 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
911 ReductionNeutralUIntMax>(
912 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
914 case vector::CombiningKind::MINSI:
915 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
916 ReductionNeutralSIntMax>(
917 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
919 case vector::CombiningKind::MAXUI:
920 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
921 ReductionNeutralUIntMin>(
922 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
924 case vector::CombiningKind::MAXSI:
925 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
926 ReductionNeutralSIntMin>(
927 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
929 case vector::CombiningKind::AND:
930 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
931 ReductionNeutralAllOnes>(
932 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
934 case vector::CombiningKind::OR:
935 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
936 ReductionNeutralZero>(
937 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
939 case vector::CombiningKind::XOR:
940 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
941 ReductionNeutralZero>(
942 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
944 case vector::CombiningKind::MINNUMF:
945 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
946 ReductionNeutralFPMax>(
947 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
949 case vector::CombiningKind::MAXNUMF:
950 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
951 ReductionNeutralFPMin>(
952 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
954 case CombiningKind::MAXIMUMF:
955 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
956 MaskNeutralFMaximum>(
957 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
959 case CombiningKind::MINIMUMF:
960 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
961 MaskNeutralFMinimum>(
962 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
967 rewriter.replaceOp(maskOp, result);
972 class VectorShuffleOpConversion
978 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
980 auto loc = shuffleOp->getLoc();
981 auto v1Type = shuffleOp.getV1VectorType();
982 auto v2Type = shuffleOp.getV2VectorType();
983 auto vectorType = shuffleOp.getResultVectorType();
985 auto maskArrayAttr = shuffleOp.getMask();
992 int64_t rank = vectorType.getRank();
994 bool wellFormed0DCase =
995 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
996 bool wellFormedNDCase =
997 v1Type.getRank() == rank && v2Type.getRank() == rank;
998 assert((wellFormed0DCase || wellFormedNDCase) &&
"op is not well-formed");
1003 if (rank <= 1 && v1Type == v2Type) {
1004 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
1005 loc, adaptor.getV1(), adaptor.getV2(),
1006 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1007 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
1012 int64_t v1Dim = v1Type.getDimSize(0);
1014 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1015 eltType = arrayType.getElementType();
1017 eltType = cast<VectorType>(llvmType).getElementType();
1018 Value insert = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
1021 int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1022 Value value = adaptor.getV1();
1023 if (extPos >= v1Dim) {
1025 value = adaptor.getV2();
1028 eltType, rank, extPos);
1029 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1030 llvmType, rank, insPos++);
1037 class VectorExtractElementOpConversion
1044 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1046 auto vectorType = extractEltOp.getSourceVectorType();
1047 auto llvmType = typeConverter->
convertType(vectorType.getElementType());
1053 if (vectorType.getRank() == 0) {
1054 Location loc = extractEltOp.getLoc();
1056 auto zero = rewriter.
create<LLVM::ConstantOp>(
1060 extractEltOp, llvmType, adaptor.getVector(), zero);
1065 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1070 class VectorExtractOpConversion
1076 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1078 auto loc = extractOp->getLoc();
1079 auto resultType = extractOp.getResult().getType();
1080 auto llvmResultType = typeConverter->
convertType(resultType);
1082 if (!llvmResultType)
1086 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1090 if (position.empty()) {
1091 rewriter.
replaceOp(extractOp, adaptor.getVector());
1096 if (isa<VectorType>(resultType)) {
1097 if (extractOp.hasDynamicPosition())
1100 Value extracted = rewriter.
create<LLVM::ExtractValueOp>(
1102 rewriter.
replaceOp(extractOp, extracted);
1107 Value extracted = adaptor.getVector();
1108 if (position.size() > 1) {
1109 if (extractOp.hasDynamicPosition())
1114 extracted = rewriter.
create<LLVM::ExtractValueOp>(loc, extracted,
1145 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1147 VectorType vType = fmaOp.getVectorType();
1148 if (vType.getRank() > 1)
1152 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1157 class VectorInsertElementOpConversion
1163 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1165 auto vectorType = insertEltOp.getDestVectorType();
1166 auto llvmType = typeConverter->
convertType(vectorType);
1172 if (vectorType.getRank() == 0) {
1173 Location loc = insertEltOp.getLoc();
1175 auto zero = rewriter.
create<LLVM::ConstantOp>(
1179 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1184 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1185 adaptor.getPosition());
1190 class VectorInsertOpConversion
1196 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1198 auto loc = insertOp->getLoc();
1199 auto sourceType = insertOp.getSourceType();
1200 auto destVectorType = insertOp.getDestVectorType();
1201 auto llvmResultType = typeConverter->
convertType(destVectorType);
1203 if (!llvmResultType)
1207 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1212 if (position.empty()) {
1213 rewriter.
replaceOp(insertOp, adaptor.getSource());
1218 if (isa<VectorType>(sourceType)) {
1219 if (insertOp.hasDynamicPosition())
1222 Value inserted = rewriter.
create<LLVM::InsertValueOp>(
1223 loc, adaptor.getDest(), adaptor.getSource(),
getAsIntegers(position));
1229 Value extracted = adaptor.getDest();
1230 auto oneDVectorType = destVectorType;
1231 if (position.size() > 1) {
1232 if (insertOp.hasDynamicPosition())
1236 extracted = rewriter.
create<LLVM::ExtractValueOp>(
1241 Value inserted = rewriter.
create<LLVM::InsertElementOp>(
1242 loc, typeConverter->
convertType(oneDVectorType), extracted,
1243 adaptor.getSource(),
getAsLLVMValue(rewriter, loc, position.back()));
1246 if (position.size() > 1) {
1247 if (insertOp.hasDynamicPosition())
1250 inserted = rewriter.
create<LLVM::InsertValueOp>(
1251 loc, adaptor.getDest(), inserted,
1261 struct VectorScalableInsertOpLowering
1267 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1270 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1276 struct VectorScalableExtractOpLowering
1282 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1285 extOp, typeConverter->
convertType(extOp.getResultVectorType()),
1286 adaptor.getSource(), adaptor.getPos());
1319 setHasBoundedRewriteRecursion();
1324 auto vType = op.getVectorType();
1325 if (vType.getRank() < 2)
1329 auto elemType = vType.getElementType();
1332 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
1333 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1334 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
1335 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
1336 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
1337 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1338 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
1347 static std::optional<SmallVector<int64_t, 4>>
1348 computeContiguousStrides(MemRefType memRefType) {
1352 return std::nullopt;
1353 if (!strides.empty() && strides.back() != 1)
1354 return std::nullopt;
1356 if (memRefType.getLayout().isIdentity())
1363 auto sizes = memRefType.getShape();
1364 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
1365 if (ShapedType::isDynamic(sizes[index + 1]) ||
1366 ShapedType::isDynamic(strides[index]) ||
1367 ShapedType::isDynamic(strides[index + 1]))
1368 return std::nullopt;
1369 if (strides[index] != strides[index + 1] * sizes[index + 1])
1370 return std::nullopt;
1375 class VectorTypeCastOpConversion
1381 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1383 auto loc = castOp->getLoc();
1384 MemRefType sourceMemRefType =
1385 cast<MemRefType>(castOp.getOperand().getType());
1386 MemRefType targetMemRefType = castOp.getType();
1389 if (!sourceMemRefType.hasStaticShape() ||
1390 !targetMemRefType.hasStaticShape())
1393 auto llvmSourceDescriptorTy =
1394 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1395 if (!llvmSourceDescriptorTy)
1399 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1401 if (!llvmTargetDescriptorTy)
1405 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1408 auto targetStrides = computeContiguousStrides(targetMemRefType);
1412 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1420 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1421 desc.setAllocatedPtr(rewriter, loc, allocated);
1424 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1425 desc.setAlignedPtr(rewriter, loc, ptr);
1428 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
1429 desc.setOffset(rewriter, loc, zero);
1432 for (
const auto &indexedSize :
1434 int64_t index = indexedSize.index();
1437 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1438 desc.setSize(rewriter, loc, index, size);
1440 (*targetStrides)[index]);
1441 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1442 desc.setStride(rewriter, loc, index, stride);
1452 class VectorCreateMaskOpRewritePattern
1455 explicit VectorCreateMaskOpRewritePattern(
MLIRContext *context,
1456 bool enableIndexOpt)
1458 force32BitVectorIndices(enableIndexOpt) {}
1462 auto dstType = op.getType();
1463 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1465 IntegerType idxType =
1468 Value indices = rewriter.
create<LLVM::StepVectorOp>(
1474 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1481 const bool force32BitVectorIndices;
1502 matchAndRewrite(vector::PrintOp
printOp, OpAdaptor adaptor,
1504 auto parent =
printOp->getParentOfType<ModuleOp>();
1510 if (
auto value = adaptor.getSource()) {
1520 auto punct =
printOp.getPunctuation();
1521 if (
auto stringLiteral =
printOp.getStringLiteral()) {
1523 *stringLiteral, *getTypeConverter(),
1525 }
else if (punct != PrintPunctuation::NoPunctuation) {
1526 emitCall(rewriter,
printOp->getLoc(), [&] {
1528 case PrintPunctuation::Close:
1529 return LLVM::lookupOrCreatePrintCloseFn(parent);
1530 case PrintPunctuation::Open:
1531 return LLVM::lookupOrCreatePrintOpenFn(parent);
1532 case PrintPunctuation::Comma:
1533 return LLVM::lookupOrCreatePrintCommaFn(parent);
1534 case PrintPunctuation::NewLine:
1535 return LLVM::lookupOrCreatePrintNewlineFn(parent);
1537 llvm_unreachable(
"unexpected punctuation");
1547 enum class PrintConversion {
1558 Value value)
const {
1570 conversion = PrintConversion::Bitcast16;
1573 conversion = PrintConversion::Bitcast16;
1577 }
else if (
auto intTy = dyn_cast<IntegerType>(
printType)) {
1581 unsigned width = intTy.getWidth();
1582 if (intTy.isUnsigned()) {
1585 conversion = PrintConversion::ZeroExt64;
1591 assert(intTy.isSignless() || intTy.isSigned());
1596 conversion = PrintConversion::ZeroExt64;
1597 else if (width < 64)
1598 conversion = PrintConversion::SignExt64;
1608 switch (conversion) {
1609 case PrintConversion::ZeroExt64:
1610 value = rewriter.
create<arith::ExtUIOp>(
1613 case PrintConversion::SignExt64:
1614 value = rewriter.
create<arith::ExtSIOp>(
1617 case PrintConversion::Bitcast16:
1618 value = rewriter.
create<LLVM::BitcastOp>(
1624 emitCall(rewriter, loc, printer, value);
1642 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1644 VectorType resultType = cast<VectorType>(splatOp.getType());
1645 if (resultType.getRank() > 1)
1649 auto vectorType = typeConverter->
convertType(splatOp.getType());
1650 Value undef = rewriter.
create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1651 auto zero = rewriter.
create<LLVM::ConstantOp>(
1657 if (resultType.getRank() == 0) {
1659 splatOp, vectorType, undef, adaptor.getInput(), zero);
1664 auto v = rewriter.
create<LLVM::InsertElementOp>(
1665 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1667 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1684 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1686 VectorType resultType = splatOp.getType();
1687 if (resultType.getRank() <= 1)
1691 auto loc = splatOp.getLoc();
1692 auto vectorTypeInfo =
1694 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1695 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1696 if (!llvmNDVectorTy || !llvm1DVectorTy)
1700 Value desc = rewriter.
create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1704 Value vdesc = rewriter.
create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1705 auto zero = rewriter.
create<LLVM::ConstantOp>(
1708 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1709 adaptor.getInput(), zero);
1712 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1714 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1719 desc = rewriter.
create<LLVM::InsertValueOp>(loc, desc, v, position);
1728 struct VectorInterleaveOpLowering
1733 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1735 VectorType resultType = interleaveOp.getResultVectorType();
1737 if (resultType.getRank() != 1)
1739 "InterleaveOp not rank 1");
1741 if (resultType.isScalable()) {
1743 interleaveOp, typeConverter->
convertType(resultType),
1744 adaptor.getLhs(), adaptor.getRhs());
1751 int64_t resultVectorSize = resultType.getNumElements();
1753 interleaveShuffleMask.reserve(resultVectorSize);
1754 for (
int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1755 interleaveShuffleMask.push_back(i);
1756 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1759 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1760 interleaveShuffleMask);
1770 bool reassociateFPReductions,
bool force32BitVectorIndices) {
1772 patterns.
add<VectorFMAOpNDRewritePattern>(ctx);
1774 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1775 patterns.
add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1776 patterns.
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1777 VectorExtractElementOpConversion, VectorExtractOpConversion,
1778 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1779 VectorInsertOpConversion, VectorPrintOpConversion,
1780 VectorTypeCastOpConversion, VectorScaleOpConversion,
1781 VectorLoadStoreConversion<vector::LoadOp>,
1782 VectorLoadStoreConversion<vector::MaskedLoadOp>,
1783 VectorLoadStoreConversion<vector::StoreOp>,
1784 VectorLoadStoreConversion<vector::MaskedStoreOp>,
1785 VectorGatherOpConversion, VectorScatterOpConversion,
1786 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1787 VectorSplatOpLowering, VectorSplatNdOpLowering,
1788 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1789 MaskedReductionOpConversion, VectorInterleaveOpLowering>(
1797 patterns.
add<VectorMatmulOpConversion>(converter);
1798 patterns.
add<VectorFlatTransposeOpConversion>(converter);
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, uint64_t vLen)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp)
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...