28 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
29 #include "mlir/Conversion/Passes.h.inc"
/// Name of the discardable attribute that this file sets on (and reads back
/// from) `memref.alloca` ops to record which in-memory tile ID the alloca
/// backs. The attribute value is handled as an IntegerAttr below.
36 static constexpr StringLiteral kInMemoryTileIdAttr(
"arm_sme.in_memory_tile_id");
/// Creates the SME `ld1{b,h,w,d,q}` tile-slice load intrinsic that matches the
/// given tile type (ZAB/ZAH/ZAS/ZAD/ZAQ) and slice layout. The `_horiz`
/// variants are used when `layout` is Horizontal; the `_vert` variants follow
/// for the other layout. Every variant is built with the same operands:
/// (maskOp, ptr, tileId, tileSliceI32). Returns the created operation.
///
/// NOTE(review): the remaining parameters (rewriter, loc, tile type), the
/// `switch` headers, the `else` branch and the closing braces are elided from
/// this listing -- confirm against the full source.
39 static Operation *createLoadTileSliceIntrinsic(
41 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
42 IntegerAttr tileId,
Value tileSliceI32) {
43 if (layout == arm_sme::TileSliceLayout::Horizontal) {
// Horizontal layout: one intrinsic per tile element size.
45 case arm_sme::ArmSMETileType::ZAB:
46 return rewriter.
create<arm_sme::aarch64_sme_ld1b_horiz>(
47 loc, maskOp, ptr, tileId, tileSliceI32);
48 case arm_sme::ArmSMETileType::ZAH:
49 return rewriter.
create<arm_sme::aarch64_sme_ld1h_horiz>(
50 loc, maskOp, ptr, tileId, tileSliceI32);
51 case arm_sme::ArmSMETileType::ZAS:
52 return rewriter.
create<arm_sme::aarch64_sme_ld1w_horiz>(
53 loc, maskOp, ptr, tileId, tileSliceI32);
54 case arm_sme::ArmSMETileType::ZAD:
55 return rewriter.
create<arm_sme::aarch64_sme_ld1d_horiz>(
56 loc, maskOp, ptr, tileId, tileSliceI32);
57 case arm_sme::ArmSMETileType::ZAQ:
58 return rewriter.
create<arm_sme::aarch64_sme_ld1q_horiz>(
59 loc, maskOp, ptr, tileId, tileSliceI32);
// Vertical-layout variants (the branch header is not visible in this
// listing).
63 case arm_sme::ArmSMETileType::ZAB:
64 return rewriter.
create<arm_sme::aarch64_sme_ld1b_vert>(
65 loc, maskOp, ptr, tileId, tileSliceI32);
66 case arm_sme::ArmSMETileType::ZAH:
67 return rewriter.
create<arm_sme::aarch64_sme_ld1h_vert>(
68 loc, maskOp, ptr, tileId, tileSliceI32);
69 case arm_sme::ArmSMETileType::ZAS:
70 return rewriter.
create<arm_sme::aarch64_sme_ld1w_vert>(
71 loc, maskOp, ptr, tileId, tileSliceI32);
72 case arm_sme::ArmSMETileType::ZAD:
73 return rewriter.
create<arm_sme::aarch64_sme_ld1d_vert>(
74 loc, maskOp, ptr, tileId, tileSliceI32);
75 case arm_sme::ArmSMETileType::ZAQ:
76 return rewriter.
create<arm_sme::aarch64_sme_ld1q_vert>(
77 loc, maskOp, ptr, tileId, tileSliceI32);
/// Creates the SME `st1{b,h,w,d,q}` tile-slice store intrinsic that matches
/// the given tile type (ZAB/ZAH/ZAS/ZAD/ZAQ) and slice layout. The `_horiz`
/// variants are used when `layout` is Horizontal; the `_vert` variants follow
/// for the other layout. Every variant is built with the same operands:
/// (maskOp, ptr, tileId, tileSliceI32). Returns the created operation.
///
/// NOTE(review): the remaining parameters (rewriter, loc, tile type), the
/// `switch` headers, the `else` branch and the closing braces are elided from
/// this listing -- confirm against the full source.
84 static Operation *createStoreTileSliceIntrinsic(
86 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
87 IntegerAttr tileId,
Value tileSliceI32) {
88 if (layout == arm_sme::TileSliceLayout::Horizontal) {
// Horizontal layout: one intrinsic per tile element size.
90 case arm_sme::ArmSMETileType::ZAB:
91 return rewriter.
create<arm_sme::aarch64_sme_st1b_horiz>(
92 loc, maskOp, ptr, tileId, tileSliceI32);
93 case arm_sme::ArmSMETileType::ZAH:
94 return rewriter.
create<arm_sme::aarch64_sme_st1h_horiz>(
95 loc, maskOp, ptr, tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAS:
97 return rewriter.
create<arm_sme::aarch64_sme_st1w_horiz>(
98 loc, maskOp, ptr, tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAD:
100 return rewriter.
create<arm_sme::aarch64_sme_st1d_horiz>(
101 loc, maskOp, ptr, tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAQ:
103 return rewriter.
create<arm_sme::aarch64_sme_st1q_horiz>(
104 loc, maskOp, ptr, tileId, tileSliceI32);
// Vertical-layout variants (the branch header is not visible in this
// listing).
108 case arm_sme::ArmSMETileType::ZAB:
109 return rewriter.
create<arm_sme::aarch64_sme_st1b_vert>(
110 loc, maskOp, ptr, tileId, tileSliceI32);
111 case arm_sme::ArmSMETileType::ZAH:
112 return rewriter.
create<arm_sme::aarch64_sme_st1h_vert>(
113 loc, maskOp, ptr, tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAS:
115 return rewriter.
create<arm_sme::aarch64_sme_st1w_vert>(
116 loc, maskOp, ptr, tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAD:
118 return rewriter.
create<arm_sme::aarch64_sme_st1d_vert>(
119 loc, maskOp, ptr, tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAQ:
121 return rewriter.
create<arm_sme::aarch64_sme_st1q_vert>(
122 loc, maskOp, ptr, tileId, tileSliceI32);
/// Queries the tile ID of `op` through the ArmSME tile-op interface.
/// The visible diagnostic text shows that a missing tile ID is reported as an
/// error, since tile IDs are expected to be allocated before this conversion.
/// NOTE(review): the null check, the error emission call and the return
/// statements are elided from this listing -- confirm the exact return value
/// on failure against the full source.
127 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
128 auto tileId = op.getTileId();
131 "expected tile ID to be allocated before conversion to LLVM");
/// Creates a `memref.alloca` sized to hold the tile used by `tileOp`.
/// The alloca has a 2-D fully dynamic memref type with the tile's element
/// type, and both dynamic sizes are `vscale * minElements`.
///
/// NOTE(review): the function name line is elided here; the call site below
/// refers to this helper as `createAllocaForTile`. The insertion point chosen
/// under the guard, the definition of `minElements` and the return statement
/// are also not visible -- confirm against the full source.
137 static memref::AllocaOp
139 FunctionOpInterface func,
140 arm_sme::ArmSMETileOpInterface tileOp) {
// Restore the caller's insertion point when this helper returns.
141 RewriterBase::InsertionGuard g(rewriter);
145 auto vscale = rewriter.
create<vector::VectorScaleOp>(loc);
146 auto tileElementType = tileOp.getTileType().getElementType();
148 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
151 rewriter.
create<arith::ConstantIndexOp>(loc, minElements);
// Runtime length of one tile dimension: vscale * minElements.
152 auto vectorLen = rewriter.
create<arith::MulIOp>(loc, vscale, minElementsOp);
153 auto alloca = rewriter.
create<memref::AllocaOp>(
154 loc, memrefType,
ValueRange{vectorLen, vectorLen});
/// Looks for an existing `memref.alloca` in the first block of `func` that is
/// tagged (via the `kInMemoryTileIdAttr` discardable attribute) with `tileId`.
/// When none matches, a new alloca is created with `createAllocaForTile` and
/// tagged with the attribute.
///
/// NOTE(review): the `continue`/`return` statements and the attribute value
/// passed to `setDiscardableAttr` are elided from this listing.
159 static memref::AllocaOp getOrCreateAllocaForTile(
161 arm_sme::ArmSMETileOpInterface tileOp,
unsigned tileId) {
// Scan only the operations of the function's first block.
164 for (
auto &op : func.getBlocks().front()) {
165 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
// The attribute may be absent or of another kind, hence dyn_cast_or_null.
168 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
169 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
172 if (inMemoryTileId.getInt() == tileId)
// No matching alloca was found: create one and tag it.
176 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
177 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
/// Pattern that handles ArmSME tile ops whose tile lives in memory rather than
/// in a real SME tile. From the visible code: the op's in-memory tile is
/// backed by an alloca, the op's tile ID is rewritten in place to
/// `zeroTileId`, and the contents of that alloca are swapped with the SME tile
/// (slice by slice) around the op.
///
/// NOTE(review): the struct header, the heads of the member functions and
/// several statements (insertion-point changes, returns) are elided from this
/// listing; member names used in the comments below are taken from their call
/// sites.
234 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
238 typeConverter, benefit) {}
// --- matchAndRewrite (head not visible) ---
243 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
// Only ops using an in-memory tile are handled here.
245 if (!tileOp.isInMemoryTile())
250 auto loc = tileOp.getLoc();
251 auto func = tileOp->getParentOfType<FunctionOpInterface>();
252 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
253 tileOp.getTileId().getInt());
// Retarget the op to `zeroTileId` (its definition is not visible here).
258 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
260 VectorType tileVectorType = tileOp.getTileType();
262 auto swapInMemoryTileWithSMETileZero = [&] {
263 emitFullTileSwap(rewriter, loc, tileAlloca,
// The swap is emitted twice; presumably once before and once after the op
// (the insertion-point changes between the calls are not visible).
274 swapInMemoryTileWithSMETileZero();
277 swapInMemoryTileWithSMETileZero();
// --- getInMemoryTileSlicePtr (head not visible) ---
// Converts the memref to its LLVM type and computes a strided element pointer
// at indices {sliceIndexI64, zero}.
286 auto llvmType = getTypeConverter()->convertType(tileMemory.
getType());
288 rewriter.
create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
289 auto zero = rewriter.
create<arith::ConstantIntOp>(loc, 0, 64);
290 auto sliceIndexI64 = rewriter.
create<arith::IndexCastOp>(
292 return getStridedElementPtr(
293 loc, llvm::cast<MemRefType>(tileMemory.
getType()),
294 descriptor.getResult(0), {sliceIndexI64, zero},
// --- emitSliceSwap (head not visible) ---
301 arm_sme::ArmSMETileType tileType, VectorType sliceType,
302 IntegerAttr tileId,
Value sliceIndex)
const {
304 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
308 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
311 auto padVector = rewriter.
create<LLVM::UndefOp>(loc, sliceType);
314 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the current slice out of the SME tile first ...
316 auto currentTileSlice = rewriter.
create<arm_sme::aarch64_sme_read_horiz>(
317 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
// ... then load the slice from memory into the tile ...
319 createLoadTileSliceIntrinsic(
320 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
321 allTruePredicate, slicePtr, tileId, sliceIndexI32);
// ... and finally store the previously read slice into the alloca.
323 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
324 rewriter.
create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
// --- emitFullTileSwap (head not visible) ---
331 arm_sme::ArmSMETileType tileType, VectorType sliceType,
332 IntegerAttr tileId)
const {
333 RewriterBase::InsertionGuard guard(rewriter);
336 rewriter.
create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
// Loop bounds: [0, minNumElts * vscale) with step 1.
337 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
338 auto upperBound = rewriter.
create<arith::MulIOp>(
339 loc, minNumElts, rewriter.
create<vector::VectorScaleOp>(loc));
340 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
341 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
// Swap one slice per loop iteration, indexed by the induction variable.
344 auto sliceIndex = forOp.getInductionVar();
345 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
/// Flag stating whether a conversion pattern also needs the spill/fill
/// conversion registered for its source op.
350 enum class RequiresSpillsAndFills { Yes, No };
/// Base template for the ArmSME-to-LLVM patterns in this file. It exposes the
/// source op type as `ArmSMEOp` and the compile-time query
/// `requiresSpillsAndFillsConversion()`; the flag defaults to `Yes`.
/// NOTE(review): the struct head and base class are elided from this listing.
356 template <
typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
357 RequiresSpillsAndFills::Yes>
359 using ArmSMEOp = SourceOp;
362 static constexpr
bool requiresSpillsAndFillsConversion() {
363 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
/// Registers a single conversion pattern. When the pattern requests it and its
/// source op implements the ArmSME tile-op interface, a
/// `ConvertArmSMESpillsAndFillsToLLVM` pattern rooted at that op's name is
/// registered as well.
/// NOTE(review): the function head and the registration of `Pattern` itself
/// are elided from this listing.
367 template <
typename Pattern>
373 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
374 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
375 typename Pattern::ArmSMEOp>,
376 typename Pattern::ArmSMEOp>) {
379 patterns.
add<ConvertArmSMESpillsAndFillsToLLVM>(
380 Pattern::ArmSMEOp::getOperationName(), typeConverter,
/// Variadic convenience wrapper: applies `addArmSMEConversionPattern` to each
/// pattern type in `Patterns` via a fold expression.
387 template <
typename... Patterns>
391 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
/// Conversion pattern for `arm_sme::GetTileOp`. It opts out of the spill/fill
/// conversion (`RequiresSpillsAndFills::No`).
/// NOTE(review): the rewrite call is partially elided; the visible arguments
/// show the op being rewritten using its own tile type -- confirm the
/// replacement op against the full source.
394 struct GetTileConversion
395 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
396 RequiresSpillsAndFills::No> {
397 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
400 matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
403 getTile, getTile.getTileType());
/// Conversion pattern for `arm_sme::ZeroOp`: emits an `aarch64_sme_zero`
/// intrinsic whose mask is a per-tile-type base mask shifted left by the
/// allocated tile ID.
/// NOTE(review): the base-mask values returned for each tile type, the
/// invocation of the `baseMaskForSize` lambda and the final op replacement are
/// partially elided from this listing.
423 struct ZeroOpConversion :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
424 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
427 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
429 auto loc = zero.getLoc();
431 auto tileId = getTileIdOrError(zero);
439 arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
// Select the base mask according to the tile type; ZAB, ZAH, ZAS and ZAD have
// visible cases, anything else hits the unreachable below.
440 auto baseMaskForSize = [&] {
442 case arm_sme::ArmSMETileType::ZAB:
446 case arm_sme::ArmSMETileType::ZAH:
451 case arm_sme::ArmSMETileType::ZAS:
456 case arm_sme::ArmSMETileType::ZAD:
461 llvm_unreachable(
"bad element size");
// Shift the base mask by the tile ID to form the final zero mask.
486 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
487 rewriter.
create<arm_sme::aarch64_sme_zero>(
492 zero, zero.getVectorType());
/// Conversion pattern for `arm_sme::LoadTileSliceOp`: computes the strided
/// element pointer for the source memref, casts the tile slice index, creates
/// the matching load intrinsic via `createLoadTileSliceIntrinsic` using the
/// op's mask and layout, and replaces the op with its tile operand.
/// NOTE(review): the tile-ID failure check and the derivation of `tileType`
/// are elided from this listing.
499 struct LoadTileSliceConversion
500 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
504 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505 arm_sme::LoadTileSliceOp::Adaptor adaptor,
507 auto loc = loadTileSliceOp.getLoc();
508 auto tileId = getTileIdOrError(loadTileSliceOp);
512 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
514 adaptor.getIndices(), rewriter);
516 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Unsigned index cast of the slice index (to i32, judging by the name).
519 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
523 auto maskOp = loadTileSliceOp.getMask();
525 auto tileVectorType = loadTileSliceOp.getVectorType();
527 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
530 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531 tileId, tileSliceI32);
// The intrinsic has no result used here; the op's result is replaced by its
// input tile value.
535 rewriter.
replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
/// Conversion pattern for `arm_sme::StoreTileSliceOp`: computes the strided
/// element pointer for the destination memref, casts the tile slice index, and
/// creates the matching store intrinsic via `createStoreTileSliceIntrinsic`
/// using the op's mask and layout.
/// NOTE(review): the call wrapping `createStoreTileSliceIntrinsic` (note the
/// extra closing parenthesis) and the derivation of `tileType` are elided --
/// presumably the op is replaced by the created intrinsic; confirm.
542 struct StoreTileSliceConversion
543 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
547 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548 arm_sme::StoreTileSliceOp::Adaptor adaptor,
550 auto loc = storeTileSliceOp.getLoc();
551 auto tileVectorType = storeTileSliceOp.getVectorType();
553 auto tileId = getTileIdOrError(storeTileSliceOp);
558 Value ptr = this->getStridedElementPtr(
559 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560 adaptor.getIndices(), rewriter);
562 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Unsigned index cast of the slice index (to i32, judging by the name).
565 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
568 auto maskOp = storeTileSliceOp.getMask();
570 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
574 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
576 tileId, tileSliceI32));
/// Conversion pattern for `arm_sme::MoveVectorToTileSliceOp`: writes the op's
/// vector into a tile slice with `aarch64_sme_write_horiz` or
/// `aarch64_sme_write_vert` (chosen by the op's layout) under an all-active
/// mask, then replaces the op with its tile operand.
/// NOTE(review): the definition of `predTy`, the constant's arguments and the
/// `break` statements are elided from this listing.
583 struct MoveVectorToTileSliceConversion
584 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
585 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
588 matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
589 arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
591 auto loc = moveVectorToTileSliceOp.getLoc();
592 auto tileType = moveVectorToTileSliceOp.getTileType();
594 auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
598 auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
601 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
// Build the all-active predicate by splatting the constant `one`.
605 auto one = rewriter.
create<arith::ConstantOp>(
610 auto allActiveMask = rewriter.
create<vector::SplatOp>(loc, predTy, one);
// Emit the write intrinsic matching the requested slice layout.
613 switch (moveVectorToTileSliceOp.getLayout()) {
614 case arm_sme::TileSliceLayout::Horizontal:
615 rewriter.
create<arm_sme::aarch64_sme_write_horiz>(
616 loc, tileId, tileSliceI32, allActiveMask,
617 moveVectorToTileSliceOp.getVector());
619 case arm_sme::TileSliceLayout::Vertical:
620 rewriter.
create<arm_sme::aarch64_sme_write_vert>(
621 loc, tileId, tileSliceI32, allActiveMask,
622 moveVectorToTileSliceOp.getVector());
// The op's result is replaced by its input tile value.
628 rewriter.
replaceOp(moveVectorToTileSliceOp,
629 moveVectorToTileSliceOp.getTile());
/// Conversion pattern for `arm_sme::MoveTileSliceToVectorOp`: prepares an
/// all-true predicate, a vector named `zeroVector` and a cast slice index, and
/// then rewrites the op according to its layout (Horizontal or Vertical),
/// passing (sliceType, zeroVector, allTruePredicate, tileId, sliceIndexI32).
/// NOTE(review): the heads of the two rewrite calls are elided from this
/// listing -- presumably they replace the op with the horizontal/vertical
/// tile read intrinsic; confirm against the full source.
636 struct MoveTileSliceToVectorConversion
637 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
638 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
641 matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
644 auto loc = moveTileSliceToVector.getLoc();
645 auto sliceType = moveTileSliceToVector.getSliceType();
646 auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
648 auto tileId = getTileIdOrError(moveTileSliceToVector);
// Predicate type: same shape as the slice, with i1 elements.
653 auto predicateType = sliceType.cloneWith({}, rewriter.
getI1Type());
654 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
658 auto zeroVector = rewriter.
create<arith::ConstantOp>(
662 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
666 switch (moveTileSliceToVector.getLayout()) {
667 case arm_sme::TileSliceLayout::Horizontal:
669 moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
670 tileId, sliceIndexI32);
672 case arm_sme::TileSliceLayout::Vertical:
674 moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
675 tileId, sliceIndexI32);
/// Conversion pattern for `arm_sme::OuterProductOp`: emits an
/// `aarch64_sme_mopa` intrinsic with the op's lhs/rhs and masks, then replaces
/// the op with the accumulator value.
///
/// Visible constraints: only the `Add` combining kind is accepted (otherwise
/// "unsupported kind" is emitted) and the result type must pass
/// `isSupportedType` (otherwise "unsupported type").
/// NOTE(review): the condition guarding the creation of the zero accumulator,
/// the tile-ID failure check and parts of `isSupportedType` are elided.
697 struct OuterProductOpConversion
698 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
699 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
702 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
703 arm_sme::OuterProductOp::Adaptor adaptor,
705 auto tileId = getTileIdOrError(outerProductOp);
// Type check: rank-2 vector with all dims scalable, element type one of
// f16/bf16/f32/f64, plus a shape comparison (right-hand side not visible).
709 auto isSupportedType = [](VectorType vectorType) {
723 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
726 auto elementType = vectorType.getElementType();
728 if (!elementType.isF16() && !elementType.isBF16() &&
729 !elementType.isF32() && !elementType.isF64())
733 vectorType.getElementTypeBitWidth();
734 return vectorType.getShape() ==
739 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
740 return outerProductOp.emitError(
"unsupported kind");
742 auto resultVectorType = outerProductOp.getResultType();
743 if (!isSupportedType(resultVectorType))
744 return outerProductOp.emitError(
"unsupported type");
746 auto loc = outerProductOp.getLoc();
748 Value acc = outerProductOp.getAcc();
// A ZeroOp (sharing this op's tile ID) can stand in as the accumulator; the
// guarding condition is not visible -- presumably "no accumulator given".
751 acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
752 rewriter, loc, resultVectorType);
754 Value lhsMask = outerProductOp.getLhsMask();
755 Value rhsMask = outerProductOp.getRhsMask();
// If either mask is missing, both are replaced by an all-active mask.
757 if (!lhsMask || !rhsMask) {
759 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
760 Value allActiveMask = rewriter.create<arith::ConstantOp>(
762 lhsMask = allActiveMask;
763 rhsMask = allActiveMask;
767 rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
768 outerProductOp.getLhs(),
769 outerProductOp.getRhs());
773 rewriter.replaceOp(outerProductOp, acc);
/// Generic conversion pattern for the widening outer-product ops: lowers
/// `OuterProductWideningOp` to the intrinsic `OuterProductWideningIntrOp`,
/// passing the tile ID, both masks and the adaptor's lhs (remaining operands
/// are elided from this listing).
/// NOTE(review): the condition guarding the creation of the zero accumulator,
/// the tile-ID failure check and the final op replacement are not visible.
780 template <
class OuterProductW
ideningOp,
class OuterProductW
ideningIntrOp>
781 struct OuterProductWideningOpConversion
782 :
public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
783 using ConvertArmSMEOpToLLVMPattern<
784 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
787 matchAndRewrite(OuterProductWideningOp op,
788 typename OuterProductWideningOp::Adaptor adaptor,
790 auto tileId = getTileIdOrError(op);
794 Value acc = op.getAcc();
// A ZeroOp (sharing this op's tile ID) can stand in as the accumulator; the
// guarding condition is not visible -- presumably "no accumulator given".
797 acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
798 rewriter, op.
getLoc(), op.getResultType());
800 Value lhsMask = op.getLhsMask();
801 Value rhsMask = op.getRhsMask();
// If either mask is missing, both are replaced by an all-active mask.
802 if (!lhsMask || !rhsMask) {
803 auto predTy = op.getLhsType().cloneWith({}, rewriter.
getI1Type());
804 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
806 lhsMask = allActiveMask;
807 rhsMask = allActiveMask;
810 rewriter.
create<OuterProductWideningIntrOp>(op.
getLoc(), tileId, lhsMask,
811 rhsMask, adaptor.getLhs(),
/// Conversion pattern for `arm_sme::StreamingVLOp`: selects the
/// `aarch64_sme_cnts{b,h,w,d}` intrinsic that matches the op's type size
/// (Byte/Half/Word/Double), creating it with result type `i64Type`, and then
/// rewrites the op using the index type and the intrinsic's first result.
/// It opts out of the spill/fill conversion (`RequiresSpillsAndFills::No`).
/// NOTE(review): the definition of `i64Type`, the lambda/variable wrapping the
/// switch (`intrOp`) and the head of the final rewrite call are elided.
833 struct StreamingVLOpConversion
834 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835 RequiresSpillsAndFills::No> {
836 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
839 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840 arm_sme::StreamingVLOp::Adaptor adaptor,
842 auto loc = streamingVlOp.getLoc();
// One counting intrinsic per element size.
845 switch (streamingVlOp.getTypeSize()) {
846 case arm_sme::TypeSize::Byte:
847 return rewriter.
create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848 case arm_sme::TypeSize::Half:
849 return rewriter.
create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850 case arm_sme::TypeSize::Word:
851 return rewriter.
create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852 case arm_sme::TypeSize::Double:
853 return rewriter.
create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
857 streamingVlOp, rewriter.
getIndexType(), intrOp->getResult(0));
/// Pass wrapper derived from the generated `ConvertArmSMEToLLVMBase`.
/// NOTE(review): almost all of `runOnOperation` is elided from this listing;
/// only the tail of a call that consumes `std::move(patterns)` is visible --
/// presumably the conversion driver invocation; confirm against the full
/// source.
866 struct ConvertArmSMEToLLVMPass
867 :
public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
868 void runOnOperation()
override {
876 std::move(patterns))))
// Conversion-target legality setup (the function head and the heads of the
// first two calls are elided from this listing).
// The first template argument list names `arm_sme::MaterializeSSATileOp` and
// the `aarch64_sme_*` intrinsic ops created by the patterns in this file;
// NOTE(review): presumably these are the ops marked legal -- confirm.
886 arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
887 arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
888 arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
889 arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
890 arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
891 arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
892 arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
893 arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
894 arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
895 arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
896 arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
897 arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
898 arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
899 arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
900 arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
901 arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
902 arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
903 arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
904 arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
905 arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
906 arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
907 arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
908 arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
// Dialect list for a second call whose head is not visible.
// NOTE(review): presumably `addLegalDialect` -- confirm.
911 vector::VectorDialect, scf::SCFDialect,
912 memref::MemRefDialect>();
// Unrealized conversion casts are explicitly legal.
913 target.
addLegalOp<UnrealizedConversionCastOp>();
// Pattern population (the function head is elided from this listing).
// A VectorType conversion callback is registered on the type converter; its
// body is not visible here.
918 converter.
addConversion([&](VectorType type) -> std::optional<Type> {
// Register every ArmSME-to-LLVM pattern defined in this file. Each widening
// outer-product op is paired with the intrinsic it lowers to.
926 addArmSMEConversionPatterns<
927 LoadTileSliceConversion, MoveTileSliceToVectorConversion,
928 MoveVectorToTileSliceConversion, StoreTileSliceConversion,
929 StreamingVLOpConversion, OuterProductOpConversion,
930 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
931 arm_sme::aarch64_sme_mopa_wide>,
932 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
933 arm_sme::aarch64_sme_mops_wide>,
934 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
935 arm_sme::aarch64_sme_smopa_za32>,
936 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
937 arm_sme::aarch64_sme_smops_za32>,
938 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
939 arm_sme::aarch64_sme_umopa_za32>,
940 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
941 arm_sme::aarch64_sme_umops_za32>,
942 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
943 arm_sme::aarch64_sme_smopa_wide>,
944 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
945 arm_sme::aarch64_sme_smops_wide>,
946 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
947 arm_sme::aarch64_sme_umopa_wide>,
948 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
949 arm_sme::aarch64_sme_umops_wide>,
950 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
951 arm_sme::aarch64_sme_sumopa_wide>,
952 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
953 arm_sme::aarch64_sme_sumops_wide>,
954 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
955 arm_sme::aarch64_sme_usmopa_wide>,
956 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
957 arm_sme::aarch64_sme_usmops_wide>,
958 ZeroOpConversion, GetTileConversion>(patterns, converter);
// Factory body: returns a newly allocated ConvertArmSMEToLLVMPass (the
// enclosing function head is elided from this listing).
962 return std::make_unique<ConvertArmSMEToLLVMPass>();
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations known to be unsupported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for operation conversions targeting the LLVM IR dialect.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the op).
void addConversion(FnT &&callback)
Register a conversion function.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit within an SME tile.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
constexpr unsigned MinStreamingVectorLengthInBits
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass()
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.
This class represents an efficient way to signal success or failure.