28 #include "llvm/ADT/ScopeExit.h"
31 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
32 #include "mlir/Conversion/Passes.h.inc"
// Name of the discardable attribute used in this file to tag a
// `memref.alloca` with a tile ID: getOrCreateAllocaForTile() reads it via
// getDiscardableAttr() to find an existing alloca for a given tile ID, and
// sets it via setDiscardableAttr() on a newly created alloca.
39 static constexpr StringLiteral kInMemoryTileIdAttr(
"arm_sme.in_memory_tile_id");
42 static Operation *createLoadTileSliceIntrinsic(
44 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
45 IntegerAttr tileId,
Value tileSliceI32) {
46 if (layout == arm_sme::TileSliceLayout::Horizontal) {
48 case arm_sme::ArmSMETileType::ZAB:
49 return rewriter.
create<arm_sme::aarch64_sme_ld1b_horiz>(
50 loc, maskOp, ptr, tileId, tileSliceI32);
51 case arm_sme::ArmSMETileType::ZAH:
52 return rewriter.
create<arm_sme::aarch64_sme_ld1h_horiz>(
53 loc, maskOp, ptr, tileId, tileSliceI32);
54 case arm_sme::ArmSMETileType::ZAS:
55 return rewriter.
create<arm_sme::aarch64_sme_ld1w_horiz>(
56 loc, maskOp, ptr, tileId, tileSliceI32);
57 case arm_sme::ArmSMETileType::ZAD:
58 return rewriter.
create<arm_sme::aarch64_sme_ld1d_horiz>(
59 loc, maskOp, ptr, tileId, tileSliceI32);
60 case arm_sme::ArmSMETileType::ZAQ:
61 return rewriter.
create<arm_sme::aarch64_sme_ld1q_horiz>(
62 loc, maskOp, ptr, tileId, tileSliceI32);
66 case arm_sme::ArmSMETileType::ZAB:
67 return rewriter.
create<arm_sme::aarch64_sme_ld1b_vert>(
68 loc, maskOp, ptr, tileId, tileSliceI32);
69 case arm_sme::ArmSMETileType::ZAH:
70 return rewriter.
create<arm_sme::aarch64_sme_ld1h_vert>(
71 loc, maskOp, ptr, tileId, tileSliceI32);
72 case arm_sme::ArmSMETileType::ZAS:
73 return rewriter.
create<arm_sme::aarch64_sme_ld1w_vert>(
74 loc, maskOp, ptr, tileId, tileSliceI32);
75 case arm_sme::ArmSMETileType::ZAD:
76 return rewriter.
create<arm_sme::aarch64_sme_ld1d_vert>(
77 loc, maskOp, ptr, tileId, tileSliceI32);
78 case arm_sme::ArmSMETileType::ZAQ:
79 return rewriter.
create<arm_sme::aarch64_sme_ld1q_vert>(
80 loc, maskOp, ptr, tileId, tileSliceI32);
87 static Operation *createStoreTileSliceIntrinsic(
89 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
90 IntegerAttr tileId,
Value tileSliceI32) {
91 if (layout == arm_sme::TileSliceLayout::Horizontal) {
93 case arm_sme::ArmSMETileType::ZAB:
94 return rewriter.
create<arm_sme::aarch64_sme_st1b_horiz>(
95 loc, maskOp, ptr, tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAH:
97 return rewriter.
create<arm_sme::aarch64_sme_st1h_horiz>(
98 loc, maskOp, ptr, tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAS:
100 return rewriter.
create<arm_sme::aarch64_sme_st1w_horiz>(
101 loc, maskOp, ptr, tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAD:
103 return rewriter.
create<arm_sme::aarch64_sme_st1d_horiz>(
104 loc, maskOp, ptr, tileId, tileSliceI32);
105 case arm_sme::ArmSMETileType::ZAQ:
106 return rewriter.
create<arm_sme::aarch64_sme_st1q_horiz>(
107 loc, maskOp, ptr, tileId, tileSliceI32);
111 case arm_sme::ArmSMETileType::ZAB:
112 return rewriter.
create<arm_sme::aarch64_sme_st1b_vert>(
113 loc, maskOp, ptr, tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAH:
115 return rewriter.
create<arm_sme::aarch64_sme_st1h_vert>(
116 loc, maskOp, ptr, tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAS:
118 return rewriter.
create<arm_sme::aarch64_sme_st1w_vert>(
119 loc, maskOp, ptr, tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAD:
121 return rewriter.
create<arm_sme::aarch64_sme_st1d_vert>(
122 loc, maskOp, ptr, tileId, tileSliceI32);
123 case arm_sme::ArmSMETileType::ZAQ:
124 return rewriter.
create<arm_sme::aarch64_sme_st1q_vert>(
125 loc, maskOp, ptr, tileId, tileSliceI32);
130 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
131 auto tileId = op.getTileId();
134 "expected tile ID to be allocated before conversion to LLVM");
140 static memref::AllocaOp
142 FunctionOpInterface func,
143 arm_sme::ArmSMETileOpInterface tileOp) {
144 RewriterBase::InsertionGuard g(rewriter);
148 auto vscale = rewriter.
create<vector::VectorScaleOp>(loc);
149 auto tileElementType = tileOp.getTileType().getElementType();
151 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
154 rewriter.
create<arith::ConstantIndexOp>(loc, minElements);
155 auto vectorLen = rewriter.
create<arith::MulIOp>(loc, vscale, minElementsOp);
156 auto alloca = rewriter.
create<memref::AllocaOp>(
157 loc, memrefType,
ValueRange{vectorLen, vectorLen});
162 static memref::AllocaOp getOrCreateAllocaForTile(
164 arm_sme::ArmSMETileOpInterface tileOp,
unsigned tileId) {
167 for (
auto &op : func.getBlocks().front()) {
168 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
171 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
172 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
175 if (inMemoryTileId.getInt() == tileId)
179 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
180 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
237 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
241 typeConverter, benefit) {}
246 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
248 if (!tileOp.isInMemoryTile())
252 "failed to allocate SME virtual tile to operation, tile value will go "
253 "through memory, expect degraded performance");
257 auto loc = tileOp.getLoc();
258 auto func = tileOp->getParentOfType<FunctionOpInterface>();
259 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
260 tileOp.getTileId().getInt());
265 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
267 VectorType tileVectorType = tileOp.getTileType();
269 auto swapInMemoryTileWithSMETileZero = [&] {
270 emitFullTileSwap(rewriter, loc, tileAlloca,
281 swapInMemoryTileWithSMETileZero();
284 swapInMemoryTileWithSMETileZero();
293 auto llvmType = getTypeConverter()->convertType(tileMemory.
getType());
295 rewriter.
create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
296 auto zero = rewriter.
create<arith::ConstantIntOp>(loc, 0, 64);
297 auto sliceIndexI64 = rewriter.
create<arith::IndexCastOp>(
299 return getStridedElementPtr(
300 loc, llvm::cast<MemRefType>(tileMemory.
getType()),
301 descriptor.getResult(0), {sliceIndexI64, zero},
308 arm_sme::ArmSMETileType tileType, VectorType sliceType,
309 IntegerAttr tileId,
Value sliceIndex)
const {
311 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
315 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
318 auto padVector = rewriter.
create<LLVM::UndefOp>(loc, sliceType);
321 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
323 auto currentTileSlice = rewriter.
create<arm_sme::aarch64_sme_read_horiz>(
324 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
326 createLoadTileSliceIntrinsic(
327 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
328 allTruePredicate, slicePtr, tileId, sliceIndexI32);
330 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
331 rewriter.
create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
338 arm_sme::ArmSMETileType tileType, VectorType sliceType,
339 IntegerAttr tileId)
const {
340 RewriterBase::InsertionGuard guard(rewriter);
343 rewriter.
create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
344 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
345 auto upperBound = rewriter.
create<arith::MulIOp>(
346 loc, minNumElts, rewriter.
create<vector::VectorScaleOp>(loc));
347 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
348 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
351 auto sliceIndex = forOp.getInductionVar();
352 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
// Compile-time switch for ArmSME conversion patterns. A pattern reports `true`
// from requiresSpillsAndFillsConversion() when this is `Yes`; patterns opt out
// by passing `No` as the second template argument.
357 enum class RequiresSpillsAndFills { Yes, No };
363 template <
typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
364 RequiresSpillsAndFills::Yes>
366 using ArmSMEOp = SourceOp;
369 static constexpr
bool requiresSpillsAndFillsConversion() {
370 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
374 template <
typename Pattern>
380 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
381 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
382 typename Pattern::ArmSMEOp>,
383 typename Pattern::ArmSMEOp>) {
386 patterns.
add<ConvertArmSMESpillsAndFillsToLLVM>(
387 Pattern::ArmSMEOp::getOperationName(), typeConverter,
394 template <
typename... Patterns>
398 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
416 struct ZeroOpConversion :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
417 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
420 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
422 auto loc = zero.getLoc();
424 auto tileId = getTileIdOrError(zero);
432 arm_sme::ArmSMETileType tileType =
434 auto baseMaskForSize = [&] {
436 case arm_sme::ArmSMETileType::ZAB:
440 case arm_sme::ArmSMETileType::ZAH:
445 case arm_sme::ArmSMETileType::ZAS:
450 case arm_sme::ArmSMETileType::ZAD:
455 llvm_unreachable(
"bad element size");
480 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
481 rewriter.
create<arm_sme::aarch64_sme_zero>(
495 struct LoadTileSliceConversion
496 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
497 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
500 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
501 arm_sme::LoadTileSliceOp::Adaptor adaptor,
503 auto loc = loadTileSliceOp.getLoc();
504 auto tileId = getTileIdOrError(loadTileSliceOp);
508 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
510 adaptor.getIndices(), rewriter);
512 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
515 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
519 auto maskOp = loadTileSliceOp.getMask();
521 auto tileVectorType = loadTileSliceOp.getVectorType();
523 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
526 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
527 tileId, tileSliceI32);
531 rewriter.
replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
538 struct StoreTileSliceConversion
539 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
540 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
543 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
544 arm_sme::StoreTileSliceOp::Adaptor adaptor,
546 auto loc = storeTileSliceOp.getLoc();
547 auto tileVectorType = storeTileSliceOp.getVectorType();
549 auto tileId = getTileIdOrError(storeTileSliceOp);
554 Value ptr = this->getStridedElementPtr(
555 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
556 adaptor.getIndices(), rewriter);
558 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
561 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
564 auto maskOp = storeTileSliceOp.getMask();
566 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
570 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
572 tileId, tileSliceI32));
579 struct InsertTileSliceConversion
580 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
581 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
584 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
585 arm_sme::InsertTileSliceOp::Adaptor adaptor,
587 auto loc = insertTileSliceOp.getLoc();
588 auto tileType = insertTileSliceOp.getTileType();
590 auto tileId = getTileIdOrError(insertTileSliceOp);
594 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
597 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
601 auto one = rewriter.
create<arith::ConstantOp>(
606 auto allActiveMask = rewriter.
create<vector::SplatOp>(loc, predTy, one);
609 switch (insertTileSliceOp.getLayout()) {
610 case arm_sme::TileSliceLayout::Horizontal:
611 rewriter.
create<arm_sme::aarch64_sme_write_horiz>(
612 loc, tileId, tileSliceI32, allActiveMask,
613 insertTileSliceOp.getVector());
615 case arm_sme::TileSliceLayout::Vertical:
616 rewriter.
create<arm_sme::aarch64_sme_write_vert>(
617 loc, tileId, tileSliceI32, allActiveMask,
618 insertTileSliceOp.getVector());
624 rewriter.
replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
631 struct ExtractTileSliceConversion
632 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
633 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
636 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
638 auto loc = extractTileSlice.getLoc();
639 auto sliceType = extractTileSlice.getSliceType();
640 auto sliceIndex = extractTileSlice.getTileSliceIndex();
642 auto tileId = getTileIdOrError(extractTileSlice);
647 auto predicateType = sliceType.cloneWith({}, rewriter.
getI1Type());
648 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
652 auto zeroVector = rewriter.
create<arith::ConstantOp>(
656 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
660 switch (extractTileSlice.getLayout()) {
661 case arm_sme::TileSliceLayout::Horizontal:
663 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
666 case arm_sme::TileSliceLayout::Vertical:
668 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
691 struct OuterProductOpConversion
692 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
693 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
696 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
697 arm_sme::OuterProductOp::Adaptor adaptor,
699 auto tileId = getTileIdOrError(outerProductOp);
703 auto isSupportedType = [](VectorType vectorType) {
717 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
720 auto elementType = vectorType.getElementType();
722 if (!elementType.isF16() && !elementType.isBF16() &&
723 !elementType.isF32() && !elementType.isF64())
727 vectorType.getElementTypeBitWidth();
728 return vectorType.getShape() ==
733 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
734 return outerProductOp.emitError(
"unsupported kind");
736 auto resultVectorType = outerProductOp.getResultType();
737 if (!isSupportedType(resultVectorType))
738 return outerProductOp.emitError(
"unsupported type");
740 auto loc = outerProductOp.getLoc();
742 Value acc = outerProductOp.getAcc();
745 auto zero = rewriter.
create<arm_sme::ZeroOp>(loc, resultVectorType);
746 zero.setTileId(tileId);
750 Value lhsMask = outerProductOp.getLhsMask();
751 Value rhsMask = outerProductOp.getRhsMask();
753 if (!lhsMask || !rhsMask) {
755 outerProductOp.getLhsType().cloneWith({}, rewriter.
getI1Type());
756 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
758 lhsMask = allActiveMask;
759 rhsMask = allActiveMask;
763 rewriter.
create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
764 outerProductOp.getLhs(),
765 outerProductOp.getRhs());
776 template <
class OuterProductW
ideningOp,
class OuterProductW
ideningIntrOp>
777 struct OuterProductWideningOpConversion
778 :
public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
779 using ConvertArmSMEOpToLLVMPattern<
780 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
783 matchAndRewrite(OuterProductWideningOp op,
784 typename OuterProductWideningOp::Adaptor adaptor,
786 auto tileId = getTileIdOrError(op);
791 Value acc = op.getAcc();
794 auto zero = rewriter.
create<arm_sme::ZeroOp>(loc, op.getResultType());
795 zero.setTileId(tileId);
799 Value lhsMask = op.getLhsMask();
800 Value rhsMask = op.getRhsMask();
801 if (!lhsMask || !rhsMask) {
802 auto predTy = op.getLhsType().cloneWith({}, rewriter.
getI1Type());
803 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
805 lhsMask = allActiveMask;
806 rhsMask = allActiveMask;
809 rewriter.
create<OuterProductWideningIntrOp>(
810 loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
831 struct StreamingVLOpConversion
832 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
833 RequiresSpillsAndFills::No> {
834 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
837 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
838 arm_sme::StreamingVLOp::Adaptor adaptor,
840 auto loc = streamingVlOp.getLoc();
843 switch (streamingVlOp.getTypeSize()) {
844 case arm_sme::TypeSize::Byte:
845 return rewriter.
create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
846 case arm_sme::TypeSize::Half:
847 return rewriter.
create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
848 case arm_sme::TypeSize::Word:
849 return rewriter.
create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
850 case arm_sme::TypeSize::Double:
851 return rewriter.
create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
855 streamingVlOp, rewriter.
getIndexType(), intrOp->getResult(0));
862 static void mergeConsecutiveTileZerosInBlock(
Block *block) {
863 uint32_t mergedZeroMask = 0;
865 auto replaceMergedZeroOps = [&] {
866 auto cleanup = llvm::make_scope_exit([&] {
868 zeroOpsToMerge.clear();
870 if (zeroOpsToMerge.size() <= 1)
873 rewriter.
create<arm_sme::aarch64_sme_zero>(
874 zeroOpsToMerge.front().getLoc(),
876 for (
auto zeroOp : zeroOpsToMerge)
880 if (
auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
881 mergedZeroMask |= zeroOp.getTileMask();
882 zeroOpsToMerge.push_back(zeroOp);
884 replaceMergedZeroOps();
887 replaceMergedZeroOps();
894 struct ConvertArmSMEToLLVMPass
895 :
public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
896 ConvertArmSMEToLLVMPass(
bool dumpTileLiveRanges) {
897 this->dumpTileLiveRanges = dumpTileLiveRanges;
899 void runOnOperation()
override {
900 auto function = getOperation();
903 return signalPassFailure();
914 function->walk(mergeConsecutiveTileZerosInBlock);
920 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
923 auto isSMETileType = [](
Type type) {
924 return arm_sme::isValidSMETileVectorType(type);
928 op->emitOpError(
"unexpected operation with SME tile type after "
929 "conversion to LLVM");
941 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
942 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
943 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
944 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
945 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
946 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
947 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
948 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
949 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
950 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
951 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
952 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
953 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
954 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
955 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
956 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
957 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
958 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
959 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
960 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
961 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
962 arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
963 arm_sme::aarch64_sme_cntsd>();
966 vector::VectorDialect, scf::SCFDialect,
967 memref::MemRefDialect>();
971 target.
addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
972 UnrealizedConversionCastOp>();
977 converter.
addConversion([&](VectorType type) -> std::optional<Type> {
985 addArmSMEConversionPatterns<
986 LoadTileSliceConversion, ExtractTileSliceConversion,
987 InsertTileSliceConversion, StoreTileSliceConversion,
988 StreamingVLOpConversion, OuterProductOpConversion,
989 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
990 arm_sme::aarch64_sme_mopa_wide>,
991 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
992 arm_sme::aarch64_sme_mops_wide>,
993 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
994 arm_sme::aarch64_sme_smopa_za32>,
995 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
996 arm_sme::aarch64_sme_smops_za32>,
997 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
998 arm_sme::aarch64_sme_umopa_za32>,
999 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1000 arm_sme::aarch64_sme_umops_za32>,
1001 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1002 arm_sme::aarch64_sme_smopa_wide>,
1003 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1004 arm_sme::aarch64_sme_smops_wide>,
1005 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1006 arm_sme::aarch64_sme_umopa_wide>,
1007 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1008 arm_sme::aarch64_sme_umops_wide>,
1009 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1010 arm_sme::aarch64_sme_sumopa_wide>,
1011 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1012 arm_sme::aarch64_sme_sumops_wide>,
1013 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1014 arm_sme::aarch64_sme_usmopa_wide>,
1015 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1016 arm_sme::aarch64_sme_usmops_wide>,
1017 ZeroOpConversion>(patterns, converter);
1020 std::unique_ptr<Pass>
1022 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations of this dialect are not supported by the target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for operation conversions targeting the LLVM IR dialect.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit within an SME tile.
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
constexpr unsigned MinStreamingVectorLengthInBits
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.