28 assert((tp.getRank() > 1) &&
"unlowerable vector type");
29 unsigned numScalableDims = tp.getNumScalableDims();
30 if (tp.getShape().size() == numScalableDims)
32 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
38 assert((tp.getRank() > 1) &&
"unlowerable vector type");
39 unsigned numScalableDims = tp.getNumScalableDims();
40 if (numScalableDims > 0)
42 return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
51 assert(rank > 0 &&
"0-D vector corner case should have been handled already");
54 auto constant = rewriter.
create<LLVM::ConstantOp>(
57 return rewriter.
create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
60 return rewriter.
create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
67 Value val,
Type llvmType, int64_t rank, int64_t pos) {
70 auto constant = rewriter.
create<LLVM::ConstantOp>(
73 return rewriter.
create<LLVM::ExtractElementOp>(loc, llvmType, val,
76 return rewriter.
create<LLVM::ExtractValueOp>(loc, llvmType, val,
82 MemRefType memrefType,
unsigned &align) {
83 Type elementTy = typeConverter.
convertType(memrefType.getElementType());
89 llvm::LLVMContext llvmContext;
99 Value index, MemRefType memRefType,
100 VectorType vType,
Value &ptrs) {
104 if (
failed(successStrides) || strides.back() != 1 ||
105 memRefType.getMemorySpaceAsInt() != 0)
109 ptrs = rewriter.
create<LLVM::GEPOp>(loc, ptrsType, base, index);
116 Value ptr, MemRefType memRefType,
Type vt) {
118 return rewriter.
create<LLVM::BitcastOp>(loc, pType, ptr);
124 using VectorScaleOpConversion =
128 class VectorBitCastOpConversion
134 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
137 VectorType resultTy = bitCastOp.getResultVectorType();
138 if (resultTy.getRank() > 1)
140 Type newResultTy = typeConverter->convertType(resultTy);
142 adaptor.getOperands()[0]);
149 class VectorMatmulOpConversion
155 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
158 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
159 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
160 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
167 class VectorFlatTransposeOpConversion
173 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
176 transOp, typeConverter->convertType(transOp.getRes().getType()),
177 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
185 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
186 vector::LoadOpAdaptor adaptor,
187 VectorType vectorTy,
Value ptr,
unsigned align,
192 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
193 vector::MaskedLoadOpAdaptor adaptor,
194 VectorType vectorTy,
Value ptr,
unsigned align,
197 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
200 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
201 vector::StoreOpAdaptor adaptor,
202 VectorType vectorTy,
Value ptr,
unsigned align,
208 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
209 vector::MaskedStoreOpAdaptor adaptor,
210 VectorType vectorTy,
Value ptr,
unsigned align,
213 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
218 template <
class LoadOrStoreOp,
class LoadOrStoreOpAdaptor>
224 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
225 typename LoadOrStoreOp::Adaptor adaptor,
228 VectorType vectorTy = loadOrStoreOp.getVectorType();
229 if (vectorTy.getRank() > 1)
232 auto loc = loadOrStoreOp->getLoc();
233 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
241 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
242 .
template cast<VectorType>();
243 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
244 adaptor.getIndices(), rewriter);
247 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
253 class VectorGatherOpConversion
259 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
261 auto loc = gather->getLoc();
262 MemRefType memRefType = gather.getMemRefType();
271 VectorType vType = gather.getVectorType();
272 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
273 adaptor.getIndices(), rewriter);
275 adaptor.getIndexVec(), memRefType, vType, ptrs)))
280 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
287 class VectorScatterOpConversion
293 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
295 auto loc = scatter->getLoc();
296 MemRefType memRefType = scatter.getMemRefType();
305 VectorType vType = scatter.getVectorType();
306 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
307 adaptor.getIndices(), rewriter);
309 adaptor.getIndexVec(), memRefType, vType, ptrs)))
314 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
321 class VectorExpandLoadOpConversion
327 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
329 auto loc = expand->getLoc();
330 MemRefType memRefType = expand.getMemRefType();
333 auto vtype = typeConverter->convertType(expand.getVectorType());
334 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
335 adaptor.getIndices(), rewriter);
338 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
344 class VectorCompressStoreOpConversion
350 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
352 auto loc = compress->getLoc();
353 MemRefType memRefType = compress.getMemRefType();
356 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
357 adaptor.getIndices(), rewriter);
360 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
369 template <
class VectorOp,
class ScalarOp>
370 static Value createIntegerReductionArithmeticOpLowering(
373 Value result = rewriter.
create<VectorOp>(loc, llvmType, vectorOperand);
375 result = rewriter.
create<ScalarOp>(loc, accumulator, result);
383 template <
class VectorOp>
384 static Value createIntegerReductionComparisonOpLowering(
386 Value vectorOperand,
Value accumulator, LLVM::ICmpPredicate predicate) {
387 Value result = rewriter.
create<VectorOp>(loc, llvmType, vectorOperand);
390 rewriter.
create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
391 result = rewriter.
create<LLVM::SelectOp>(loc, cmp, accumulator, result);
397 class VectorReductionOpConversion
401 bool reassociateFPRed)
403 reassociateFPReductions(reassociateFPRed) {}
406 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
408 auto kind = reductionOp.getKind();
409 Type eltType = reductionOp.getDest().getType();
410 Type llvmType = typeConverter->convertType(eltType);
411 Value operand = adaptor.getVector();
412 Value acc = adaptor.getAcc();
413 Location loc = reductionOp.getLoc();
418 case vector::CombiningKind::ADD:
420 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
422 rewriter, loc, llvmType, operand, acc);
424 case vector::CombiningKind::MUL:
426 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
428 rewriter, loc, llvmType, operand, acc);
430 case vector::CombiningKind::MINUI:
431 result = createIntegerReductionComparisonOpLowering<
432 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
433 LLVM::ICmpPredicate::ule);
435 case vector::CombiningKind::MINSI:
436 result = createIntegerReductionComparisonOpLowering<
437 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
438 LLVM::ICmpPredicate::sle);
440 case vector::CombiningKind::MAXUI:
441 result = createIntegerReductionComparisonOpLowering<
442 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
443 LLVM::ICmpPredicate::uge);
445 case vector::CombiningKind::MAXSI:
446 result = createIntegerReductionComparisonOpLowering<
447 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
448 LLVM::ICmpPredicate::sge);
450 case vector::CombiningKind::AND:
452 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
454 rewriter, loc, llvmType, operand, acc);
456 case vector::CombiningKind::OR:
458 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
460 rewriter, loc, llvmType, operand, acc);
462 case vector::CombiningKind::XOR:
464 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
466 rewriter, loc, llvmType, operand, acc);
480 if (kind == vector::CombiningKind::ADD) {
482 Value acc = adaptor.getOperands().size() > 1
483 ? adaptor.getOperands()[1]
484 : rewriter.
create<LLVM::ConstantOp>(
485 reductionOp->getLoc(), llvmType,
488 reductionOp, llvmType, acc, operand,
490 }
else if (kind == vector::CombiningKind::MUL) {
492 Value acc = adaptor.getOperands().size() > 1
493 ? adaptor.getOperands()[1]
494 : rewriter.
create<LLVM::ConstantOp>(
495 reductionOp->getLoc(), llvmType,
498 reductionOp, llvmType, acc, operand,
500 }
else if (kind == vector::CombiningKind::MINF)
505 else if (kind == vector::CombiningKind::MAXF)
516 const bool reassociateFPReductions;
519 class VectorShuffleOpConversion
525 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
527 auto loc = shuffleOp->getLoc();
528 auto v1Type = shuffleOp.getV1VectorType();
529 auto v2Type = shuffleOp.getV2VectorType();
532 auto maskArrayAttr = shuffleOp.getMask();
540 assert(v1Type.getRank() == rank);
541 assert(v2Type.getRank() == rank);
542 int64_t v1Dim = v1Type.getDimSize(0);
546 if (rank == 1 && v1Type == v2Type) {
547 Value llvmShuffleOp = rewriter.
create<LLVM::ShuffleVectorOp>(
548 loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
549 rewriter.
replaceOp(shuffleOp, llvmShuffleOp);
556 eltType = arrayType.getElementType();
559 Value insert = rewriter.
create<LLVM::UndefOp>(loc, llvmType);
562 int64_t extPos = en.value().cast<IntegerAttr>().getInt();
564 if (extPos >= v1Dim) {
566 value = adaptor.getV2();
569 eltType, rank, extPos);
570 insert =
insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
571 llvmType, rank, insPos++);
578 class VectorExtractElementOpConversion
585 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
587 auto vectorType = extractEltOp.getVectorType();
588 auto llvmType = typeConverter->convertType(
vectorType.getElementType());
595 Location loc = extractEltOp.getLoc();
597 auto zero = rewriter.
create<LLVM::ConstantOp>(
598 loc, typeConverter->convertType(idxType),
601 extractEltOp, llvmType, adaptor.getVector(), zero);
606 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
611 class VectorExtractOpConversion
617 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
619 auto loc = extractOp->getLoc();
621 auto resultType = extractOp.getResult().getType();
622 auto llvmResultType = typeConverter->convertType(resultType);
623 auto positionArrayAttr = extractOp.getPosition();
630 if (positionArrayAttr.empty()) {
631 rewriter.
replaceOp(extractOp, adaptor.getVector());
636 if (resultType.isa<VectorType>()) {
637 Value extracted = rewriter.
create<LLVM::ExtractValueOp>(
638 loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
639 rewriter.
replaceOp(extractOp, extracted);
644 auto *context = extractOp->getContext();
645 Value extracted = adaptor.getVector();
646 auto positionAttrs = positionArrayAttr.getValue();
647 if (positionAttrs.size() > 1) {
649 auto nMinusOnePositionAttrs =
650 ArrayAttr::get(context, positionAttrs.drop_back());
651 extracted = rewriter.
create<LLVM::ExtractValueOp>(
652 loc, typeConverter->convertType(oneDVectorType), extracted,
653 nMinusOnePositionAttrs);
657 auto position = positionAttrs.back().cast<IntegerAttr>();
658 auto i64Type = IntegerType::get(rewriter.
getContext(), 64);
659 auto constant = rewriter.
create<LLVM::ConstantOp>(loc, i64Type, position);
661 rewriter.
create<LLVM::ExtractElementOp>(loc, extracted, constant);
662 rewriter.
replaceOp(extractOp, extracted);
687 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
689 VectorType vType = fmaOp.getVectorType();
690 if (vType.getRank() != 1)
693 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
698 class VectorInsertElementOpConversion
704 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
706 auto vectorType = insertEltOp.getDestVectorType();
707 auto llvmType = typeConverter->convertType(
vectorType);
714 Location loc = insertEltOp.getLoc();
716 auto zero = rewriter.
create<LLVM::ConstantOp>(
717 loc, typeConverter->convertType(idxType),
720 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
725 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
726 adaptor.getPosition());
731 class VectorInsertOpConversion
737 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
739 auto loc = insertOp->getLoc();
740 auto sourceType = insertOp.getSourceType();
741 auto destVectorType = insertOp.getDestVectorType();
742 auto llvmResultType = typeConverter->convertType(destVectorType);
743 auto positionArrayAttr = insertOp.getPosition();
751 if (positionArrayAttr.empty()) {
752 rewriter.
replaceOp(insertOp, adaptor.getSource());
757 if (sourceType.isa<VectorType>()) {
758 Value inserted = rewriter.
create<LLVM::InsertValueOp>(
759 loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
766 auto *context = insertOp->getContext();
767 Value extracted = adaptor.getDest();
768 auto positionAttrs = positionArrayAttr.getValue();
769 auto position = positionAttrs.back().
cast<IntegerAttr>();
770 auto oneDVectorType = destVectorType;
771 if (positionAttrs.size() > 1) {
773 auto nMinusOnePositionAttrs =
774 ArrayAttr::get(context, positionAttrs.drop_back());
775 extracted = rewriter.
create<LLVM::ExtractValueOp>(
776 loc, typeConverter->convertType(oneDVectorType), extracted,
777 nMinusOnePositionAttrs);
781 auto i64Type = IntegerType::get(rewriter.
getContext(), 64);
782 auto constant = rewriter.
create<LLVM::ConstantOp>(loc, i64Type, position);
783 Value inserted = rewriter.
create<LLVM::InsertElementOp>(
784 loc, typeConverter->convertType(oneDVectorType), extracted,
785 adaptor.getSource(), constant);
788 if (positionAttrs.size() > 1) {
789 auto nMinusOnePositionAttrs =
790 ArrayAttr::get(context, positionAttrs.drop_back());
791 inserted = rewriter.
create<LLVM::InsertValueOp>(
792 loc, llvmResultType, adaptor.getDest(), inserted,
793 nMinusOnePositionAttrs);
829 setHasBoundedRewriteRecursion();
834 auto vType = op.getVectorType();
835 if (vType.getRank() < 2)
838 auto loc = op.getLoc();
839 auto elemType = vType.getElementType();
842 Value desc = rewriter.
create<vector::SplatOp>(loc, vType, zero);
843 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
844 Value extrLHS = rewriter.
create<ExtractOp>(loc, op.getLhs(), i);
845 Value extrRHS = rewriter.
create<ExtractOp>(loc, op.getRhs(), i);
846 Value extrACC = rewriter.
create<ExtractOp>(loc, op.getAcc(), i);
847 Value fma = rewriter.
create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
848 desc = rewriter.
create<InsertOp>(loc, fma, desc, i);
858 computeContiguousStrides(MemRefType memRefType) {
863 if (!strides.empty() && strides.back() != 1)
866 if (memRefType.getLayout().isIdentity())
873 auto sizes = memRefType.getShape();
874 for (
int index = 0, e = strides.size() - 1; index < e; ++index) {
875 if (ShapedType::isDynamic(sizes[index + 1]) ||
876 ShapedType::isDynamicStrideOrOffset(strides[index]) ||
877 ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
879 if (strides[index] != strides[index + 1] * sizes[index + 1])
885 class VectorTypeCastOpConversion
891 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
893 auto loc = castOp->getLoc();
894 MemRefType sourceMemRefType =
896 MemRefType targetMemRefType = castOp.getType();
899 if (!sourceMemRefType.hasStaticShape() ||
900 !targetMemRefType.hasStaticShape())
903 auto llvmSourceDescriptorTy =
905 if (!llvmSourceDescriptorTy)
909 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
911 if (!llvmTargetDescriptorTy)
915 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
918 auto targetStrides = computeContiguousStrides(targetMemRefType);
922 if (llvm::any_of(*targetStrides, [](int64_t stride) {
923 return ShapedType::isDynamicStrideOrOffset(stride);
927 auto int64Ty = IntegerType::get(rewriter.
getContext(), 64);
931 Type llvmTargetElementTy = desc.getElementPtrType();
933 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
935 rewriter.
create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
936 desc.setAllocatedPtr(rewriter, loc, allocated);
938 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
939 ptr = rewriter.
create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
940 desc.setAlignedPtr(rewriter, loc, ptr);
943 auto zero = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, attr);
944 desc.setOffset(rewriter, loc, zero);
947 for (
const auto &indexedSize :
949 int64_t index = indexedSize.index();
952 auto size = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
953 desc.setSize(rewriter, loc, index, size);
955 (*targetStrides)[index]);
956 auto stride = rewriter.
create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
957 desc.setStride(rewriter, loc, index, stride);
967 class VectorCreateMaskOpRewritePattern
970 explicit VectorCreateMaskOpRewritePattern(
MLIRContext *context,
973 force32BitVectorIndices(enableIndexOpt) {}
977 auto dstType = op.getType();
978 if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
980 IntegerType idxType =
982 auto loc = op->getLoc();
983 Value indices = rewriter.
create<LLVM::StepVectorOp>(
989 Value comp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
996 const bool force32BitVectorIndices;
1016 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1020 if (typeConverter->convertType(printType) ==
nullptr)
1026 Type eltType = vectorType ? vectorType.getElementType() :
printType;
1028 if (eltType.isF32()) {
1031 }
else if (eltType.isF64()) {
1034 }
else if (eltType.isIndex()) {
1037 }
else if (
auto intTy = eltType.dyn_cast<IntegerType>()) {
1041 unsigned width = intTy.getWidth();
1042 if (intTy.isUnsigned()) {
1045 conversion = PrintConversion::ZeroExt64;
1047 printOp->getParentOfType<ModuleOp>());
1052 assert(intTy.isSignless() || intTy.isSigned());
1057 conversion = PrintConversion::ZeroExt64;
1058 else if (width < 64)
1059 conversion = PrintConversion::SignExt64;
1061 printOp->getParentOfType<ModuleOp>());
1071 int64_t rank = vectorType ? vectorType.getRank() : 0;
1073 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
1075 emitCall(rewriter, printOp->getLoc(),
1077 printOp->getParentOfType<ModuleOp>()));
1083 enum class PrintConversion {
1093 PrintConversion conversion)
const {
1097 assert(rank == 0 &&
"The scalar case expects rank == 0");
1098 switch (conversion) {
1099 case PrintConversion::ZeroExt64:
1100 value = rewriter.
create<arith::ExtUIOp>(
1101 loc, IntegerType::get(rewriter.
getContext(), 64), value);
1103 case PrintConversion::SignExt64:
1104 value = rewriter.
create<arith::ExtSIOp>(
1105 loc, IntegerType::get(rewriter.
getContext(), 64), value);
1110 emitCall(rewriter, loc, printer, value);
1114 emitCall(rewriter, loc,
1120 auto reducedType = vectorType.getElementType();
1121 auto llvmType = typeConverter->convertType(reducedType);
1122 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1123 for (int64_t d = 0; d < dim; ++d) {
1124 Value nestedVal =
extractOne(rewriter, *getTypeConverter(), loc, value,
1126 emitRanks(rewriter, op, nestedVal, reducedType, printer, 0,
1129 emitCall(rewriter, loc, printComma);
1137 int64_t dim = vectorType.getDimSize(0);
1138 for (int64_t d = 0; d < dim; ++d) {
1140 auto llvmType = typeConverter->convertType(reducedType);
1141 Value nestedVal =
extractOne(rewriter, *getTypeConverter(), loc, value,
1143 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1146 emitCall(rewriter, loc, printComma);
1148 emitCall(rewriter, loc,
1155 rewriter.
create<LLVM::CallOp>(loc,
TypeRange(), SymbolRefAttr::get(ref),
1166 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1168 VectorType resultType = splatOp.getType().cast<VectorType>();
1169 if (resultType.getRank() > 1)
1173 auto vectorType = typeConverter->convertType(splatOp.getType());
1175 auto zero = rewriter.
create<LLVM::ConstantOp>(
1181 if (resultType.getRank() == 0) {
1183 splatOp,
vectorType, undef, adaptor.getInput(), zero);
1188 auto v = rewriter.
create<LLVM::InsertElementOp>(
1189 splatOp.getLoc(),
vectorType, undef, adaptor.getInput(), zero);
1191 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1209 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1211 VectorType resultType = splatOp.getType();
1212 if (resultType.getRank() <= 1)
1216 auto loc = splatOp.getLoc();
1217 auto vectorTypeInfo =
1219 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1220 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1221 if (!llvmNDVectorTy || !llvm1DVectorTy)
1225 Value desc = rewriter.
create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1229 Value vdesc = rewriter.
create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1230 auto zero = rewriter.
create<LLVM::ConstantOp>(
1233 Value v = rewriter.
create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1234 adaptor.getInput(), zero);
1237 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1240 v = rewriter.
create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
1245 desc = rewriter.
create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
1258 bool reassociateFPReductions,
bool force32BitVectorIndices) {
1260 patterns.
add<VectorFMAOpNDRewritePattern>(ctx);
1262 patterns.
add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1263 patterns.
add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1265 .
add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1266 VectorExtractElementOpConversion, VectorExtractOpConversion,
1267 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1268 VectorInsertOpConversion, VectorPrintOpConversion,
1269 VectorTypeCastOpConversion, VectorScaleOpConversion,
1270 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1271 VectorLoadStoreConversion<vector::MaskedLoadOp,
1272 vector::MaskedLoadOpAdaptor>,
1273 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1274 VectorLoadStoreConversion<vector::MaskedStoreOp,
1275 vector::MaskedStoreOpAdaptor>,
1276 VectorGatherOpConversion, VectorScatterOpConversion,
1277 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1278 VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
1285 patterns.
add<VectorMatmulOpConversion>(converter);
1286 patterns.
add<VectorFlatTransposeOpConversion>(converter);
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter)
MLIRContext * getContext() const
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation is a basic unit of execution within MLIR.
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp)
Attribute getZeroAttr(Type type)
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
static Value insertOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
static VectorType reducedVectorTypeBack(VectorType tp)
static VectorType reducedVectorTypeFront(VectorType tp)
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI32IntegerAttr(int32_t value)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls...
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns)
Populate patterns with the following patterns.
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp)
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIR_CRUNNERUTILS_EXPORT void printComma()
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
const llvm::DataLayout & getDataLayout()
Returns the data layout to use during and after conversion.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
IntegerType getIntegerType(unsigned width)
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
This class provides an abstraction over the various different ranges of value types.
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Location getLoc()
The source location the operation was defined or derived from.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, Value memref, Value base, Value index, MemRefType memRefType, VectorType vType, Value &ptrs)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp)
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Type getType() const
Return the type of this attribute.
static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp)
Type getType() const
Return the type of this value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Do not split vector transfer operations.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr, MemRefType memRefType, Type vt)
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Conversion from types to the LLVM IR dialect.
BoolAttr getBoolAttr(bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayAttr)> fun)
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, llvm::Optional< unsigned > maxTransferRank=llvm::None)
Collect a set of transfer read/write lowering patterns.
This class implements a pattern rewriter for use with ConversionPatterns.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp)
This class provides an abstraction over the different types of ranges over Values.
LLVM::LLVMDialect * getDialect()
Returns the LLVM dialect.