28 #include "llvm/ADT/ScopeExit.h"
31 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
32 #include "mlir/Conversion/Passes.h.inc"
/// Name of the discardable attribute placed on a memref.alloca to record which
/// in-memory (spilled) tile ID the alloca backs. getOrCreateAllocaForTile reads
/// it to find an existing buffer for a tile ID and sets it on newly created
/// buffers.
39 static constexpr StringLiteral kInMemoryTileIdAttr(
"arm_sme.in_memory_tile_id");
/// Creates the SME ld1* intrinsic that loads one slice of a ZA tile.
/// The loaded data comes from `ptr`, predicated by `maskOp`, and lands in
/// tile `tileId` at slice index `tileSliceI32`. The intrinsic is chosen by:
///   - `layout`: Horizontal -> *_horiz intrinsics, otherwise -> *_vert;
///   - the SME tile type: ZAB/ZAH/ZAS/ZAD/ZAQ -> ld1b/ld1h/ld1w/ld1d/ld1q.
/// Returns the created intrinsic op.
/// NOTE(review): this listing elides some source lines, including the
/// tile-type parameter and the `switch` headers.
42 static Operation *createLoadTileSliceIntrinsic(
44 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
45 IntegerAttr tileId,
Value tileSliceI32) {
46 if (layout == arm_sme::TileSliceLayout::Horizontal) {
// Horizontal slice loads, one intrinsic per element width.
48 case arm_sme::ArmSMETileType::ZAB:
49 return rewriter.
create<arm_sme::aarch64_sme_ld1b_horiz>(
50 loc, maskOp, ptr, tileId, tileSliceI32);
51 case arm_sme::ArmSMETileType::ZAH:
52 return rewriter.
create<arm_sme::aarch64_sme_ld1h_horiz>(
53 loc, maskOp, ptr, tileId, tileSliceI32);
54 case arm_sme::ArmSMETileType::ZAS:
55 return rewriter.
create<arm_sme::aarch64_sme_ld1w_horiz>(
56 loc, maskOp, ptr, tileId, tileSliceI32);
57 case arm_sme::ArmSMETileType::ZAD:
58 return rewriter.
create<arm_sme::aarch64_sme_ld1d_horiz>(
59 loc, maskOp, ptr, tileId, tileSliceI32);
60 case arm_sme::ArmSMETileType::ZAQ:
61 return rewriter.
create<arm_sme::aarch64_sme_ld1q_horiz>(
62 loc, maskOp, ptr, tileId, tileSliceI32);
// Vertical slice loads, mirroring the horizontal cases above.
66 case arm_sme::ArmSMETileType::ZAB:
67 return rewriter.
create<arm_sme::aarch64_sme_ld1b_vert>(
68 loc, maskOp, ptr, tileId, tileSliceI32);
69 case arm_sme::ArmSMETileType::ZAH:
70 return rewriter.
create<arm_sme::aarch64_sme_ld1h_vert>(
71 loc, maskOp, ptr, tileId, tileSliceI32);
72 case arm_sme::ArmSMETileType::ZAS:
73 return rewriter.
create<arm_sme::aarch64_sme_ld1w_vert>(
74 loc, maskOp, ptr, tileId, tileSliceI32);
75 case arm_sme::ArmSMETileType::ZAD:
76 return rewriter.
create<arm_sme::aarch64_sme_ld1d_vert>(
77 loc, maskOp, ptr, tileId, tileSliceI32);
78 case arm_sme::ArmSMETileType::ZAQ:
79 return rewriter.
create<arm_sme::aarch64_sme_ld1q_vert>(
80 loc, maskOp, ptr, tileId, tileSliceI32);
// Every case above returns, so this is reached only for a tile type that has
// no case.
84 llvm_unreachable(
"unknown type in createLoadTileSliceIntrinsic");
/// Creates the SME st1* intrinsic that stores one slice of a ZA tile.
/// The slice at index `tileSliceI32` of tile `tileId` is written to memory at
/// `ptr`, predicated by `maskOp`. The intrinsic is chosen by:
///   - `layout`: Horizontal -> *_horiz intrinsics, otherwise -> *_vert;
///   - the SME tile type: ZAB/ZAH/ZAS/ZAD/ZAQ -> st1b/st1h/st1w/st1d/st1q.
/// Returns the created intrinsic op.
/// NOTE(review): this listing elides some source lines, including the
/// tile-type parameter and the `switch` headers.
88 static Operation *createStoreTileSliceIntrinsic(
90 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
91 IntegerAttr tileId,
Value tileSliceI32) {
92 if (layout == arm_sme::TileSliceLayout::Horizontal) {
// Horizontal slice stores, one intrinsic per element width.
94 case arm_sme::ArmSMETileType::ZAB:
95 return rewriter.
create<arm_sme::aarch64_sme_st1b_horiz>(
96 loc, maskOp, ptr, tileId, tileSliceI32);
97 case arm_sme::ArmSMETileType::ZAH:
98 return rewriter.
create<arm_sme::aarch64_sme_st1h_horiz>(
99 loc, maskOp, ptr, tileId, tileSliceI32);
100 case arm_sme::ArmSMETileType::ZAS:
101 return rewriter.
create<arm_sme::aarch64_sme_st1w_horiz>(
102 loc, maskOp, ptr, tileId, tileSliceI32);
103 case arm_sme::ArmSMETileType::ZAD:
104 return rewriter.
create<arm_sme::aarch64_sme_st1d_horiz>(
105 loc, maskOp, ptr, tileId, tileSliceI32);
106 case arm_sme::ArmSMETileType::ZAQ:
107 return rewriter.
create<arm_sme::aarch64_sme_st1q_horiz>(
108 loc, maskOp, ptr, tileId, tileSliceI32);
// Vertical slice stores, mirroring the horizontal cases above.
112 case arm_sme::ArmSMETileType::ZAB:
113 return rewriter.
create<arm_sme::aarch64_sme_st1b_vert>(
114 loc, maskOp, ptr, tileId, tileSliceI32);
115 case arm_sme::ArmSMETileType::ZAH:
116 return rewriter.
create<arm_sme::aarch64_sme_st1h_vert>(
117 loc, maskOp, ptr, tileId, tileSliceI32);
118 case arm_sme::ArmSMETileType::ZAS:
119 return rewriter.
create<arm_sme::aarch64_sme_st1w_vert>(
120 loc, maskOp, ptr, tileId, tileSliceI32);
121 case arm_sme::ArmSMETileType::ZAD:
122 return rewriter.
create<arm_sme::aarch64_sme_st1d_vert>(
123 loc, maskOp, ptr, tileId, tileSliceI32);
124 case arm_sme::ArmSMETileType::ZAQ:
125 return rewriter.
create<arm_sme::aarch64_sme_st1q_vert>(
126 loc, maskOp, ptr, tileId, tileSliceI32);
// Every case above returns, so this is reached only for a tile type that has
// no case.
129 llvm_unreachable(
"unknown type in createStoreTileSliceIntrinsic");
/// Returns the tile ID that tile allocation assigned to `op`.
/// The error message below is emitted for an op that reaches LLVM conversion
/// without an allocated tile ID. The null check and the return value on that
/// path are elided from this listing; presumably a null attribute is returned
/// and callers bail out. Confirm against the full source.
/// NOTE(review): unlike the neighbouring helpers, this function is not
/// `static`. If it is not inside an anonymous namespace (not visible here), it
/// has external linkage.
132 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
133 auto tileId = op.getTileId();
136 "expected tile ID to be allocated before conversion to LLVM");
/// Creates a memref.alloca that backs the tile used by `tileOp` in memory.
/// The alloca is a 2-D memref with two dynamic dimensions, both set to
/// vectorLen = vscale * minElements. Its element type is the element type of
/// the op's tile vector type.
/// The InsertionGuard restores the caller's insertion point on return. The
/// insertion point chosen inside this function is on elided lines.
142 static memref::AllocaOp
144 FunctionOpInterface func,
145 arm_sme::ArmSMETileOpInterface tileOp) {
146 RewriterBase::InsertionGuard g(rewriter);
150 auto vscale = rewriter.
create<vector::VectorScaleOp>(loc);
151 auto tileElementType = tileOp.getTileType().getElementType();
153 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
// `minElements` is computed on an elided line (presumably the minimum number
// of elements per tile slice for tileElementType).
156 rewriter.
create<arith::ConstantIndexOp>(loc, minElements);
// Runtime row/column count: the minimum element count scaled by vscale.
157 auto vectorLen = rewriter.
create<arith::MulIOp>(loc, vscale, minElementsOp);
158 auto alloca = rewriter.
create<memref::AllocaOp>(
159 loc, memrefType,
ValueRange{vectorLen, vectorLen});
/// Returns the memref.alloca in the first block of `func` whose
/// kInMemoryTileIdAttr equals `tileId`.
/// If no such alloca exists, a new one is created with createAllocaForTile and
/// tagged with the attribute. Every op in `func` that uses the same in-memory
/// tile ID therefore shares one buffer.
164 static memref::AllocaOp getOrCreateAllocaForTile(
166 arm_sme::ArmSMETileOpInterface tileOp,
unsigned tileId) {
// Only the function's first (entry) block is searched.
169 for (
auto &op : func.getBlocks().front()) {
170 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
// Allocas without the attribute (or non-allocas) are presumably skipped on
// the elided lines around here.
173 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
174 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
177 if (inMemoryTileId.getInt() == tileId)
// No existing buffer for this tile ID: create one and tag it.
181 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
182 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
/// Pattern for ArmSME tile ops whose virtual tile was not assigned a hardware
/// tile ("in-memory" tiles).
/// The op is rewritten in place to use `zeroTileId`. The contents of that
/// hardware tile are swapped with a per-tile-ID memref.alloca buffer, so the
/// op operates on the spilled tile's data. A degraded-performance warning is
/// emitted.
/// NOTE(review): this listing elides the class header, matchAndRewrite's
/// signature and several helper signature lines. The comments below describe
/// only the visible code.
239 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
243 typeConverter, benefit) {}
// matchAndRewrite (signature elided): handles one tile op.
248 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
// Only in-memory tiles are handled here; the early exit is elided.
250 if (!tileOp.isInMemoryTile())
254 "failed to allocate SME virtual tile to operation, tile value will go "
255 "through memory, expect degraded performance");
259 auto loc = tileOp.getLoc();
260 auto func = tileOp->getParentOfType<FunctionOpInterface>();
// One buffer per in-memory tile ID, shared across the function.
261 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
262 tileOp.getTileId().getInt());
// Retarget the op to `zeroTileId` (definition elided; presumably tile 0).
267 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
269 VectorType tileVectorType = tileOp.getTileType();
// Swaps the whole hardware tile with the in-memory buffer, slice by slice.
271 auto swapInMemoryTileWithSMETileZero = [&] {
272 emitFullTileSwap(rewriter, loc, tileAlloca,
// A swap is its own inverse. Emitting it twice (insertion-point changes are
// elided; presumably once before and once after the op) works as follows:
// - the first swap moves the spilled tile into the hardware tile;
// - the second writes it back and restores the hardware tile's old contents.
283 swapInMemoryTileWithSMETileZero();
286 swapInMemoryTileWithSMETileZero();
// getInMemoryTileSlicePtr (signature elided): returns an LLVM pointer to
// element [sliceIndex, 0] of `tileMemory`. The memref is cast to its converted
// LLVM type with an unrealized_conversion_cast so that getStridedElementPtr
// can address it.
295 auto llvmType = getTypeConverter()->convertType(tileMemory.
getType());
297 rewriter.
create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
298 auto zero = rewriter.
create<arith::ConstantIntOp>(loc, 0, 64);
299 auto sliceIndexI64 = rewriter.
create<arith::IndexCastOp>(
301 return getStridedElementPtr(
302 loc, llvm::cast<MemRefType>(tileMemory.
getType()),
303 descriptor.getResult(0), {sliceIndexI64, zero},
// emitSliceSwap (leading signature lines elided): swaps horizontal slice
// `sliceIndex` of hardware tile `tileId` with the matching buffer row:
//   1. read the current tile slice (read_horiz, all-true predicate, undef
//      passthrough);
//   2. load the buffer row into that slice (horizontal ld1*);
//   3. store the previously read slice into the buffer (vector.store).
// Reading before loading is what makes this a swap rather than a copy.
310 arm_sme::ArmSMETileType tileType, VectorType sliceType,
311 IntegerAttr tileId,
Value sliceIndex)
const {
313 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
317 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
320 auto padVector = rewriter.
create<LLVM::UndefOp>(loc, sliceType);
323 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
325 auto currentTileSlice = rewriter.
create<arm_sme::aarch64_sme_read_horiz>(
326 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
332 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
333 rewriter.
create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
// emitFullTileSwap (leading signature lines elided): emits an scf.for over
// slice indices [0, sliceType.getDimSize(0) * vscale) with step 1 and swaps
// each slice with emitSliceSwap. The InsertionGuard restores the caller's
// insertion point afterwards.
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId)
const {
342 RewriterBase::InsertionGuard guard(rewriter);
345 rewriter.
create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
346 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
347 auto upperBound = rewriter.
create<arith::MulIOp>(
348 loc, minNumElts, rewriter.
create<vector::VectorScaleOp>(loc));
349 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
350 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
353 auto sliceIndex = forOp.getInductionVar();
354 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
/// Whether a conversion pattern also wants the spills/fills pattern
/// (ConvertArmSMESpillsAndFillsToLLVM) registered for its source op.
/// addArmSMEConversionPattern consults this choice.
359 enum class RequiresSpillsAndFills { Yes, No };
/// Base class for ArmSME-op-to-LLVM-intrinsic patterns.
/// It exposes the source op type as `ArmSMEOp` and, at compile time, whether
/// the spills/fills pattern should also be registered for that op. That
/// defaults to RequiresSpillsAndFills::Yes.
365 template <
typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
366 RequiresSpillsAndFills::Yes>
368 using ArmSMEOp = SourceOp;
371 static constexpr
bool requiresSpillsAndFillsConversion() {
372 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
/// Registers `Pattern`, plus a spills/fills pattern for its op when both hold:
///   - the pattern requests it (requiresSpillsAndFillsConversion());
///   - the op implements ArmSMETileOpInterface (checked via its trait).
/// The registration of `Pattern` itself is on elided lines.
376 template <
typename Pattern>
382 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
383 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
384 typename Pattern::ArmSMEOp>,
385 typename Pattern::ArmSMEOp>) {
// The spills/fills pattern matches by operation name, so one instance is
// added per tile op kind.
388 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
389 Pattern::ArmSMEOp::getOperationName(), typeConverter,
/// Calls addArmSMEConversionPattern for each type in `Patterns`, in order,
/// using a comma fold expression.
396 template <
typename... Patterns>
400 (addArmSMEConversionPattern<Patterns>(
patterns, typeConverter), ...);
/// Lowers arm_sme.zero to the aarch64_sme_zero intrinsic.
/// The zero mask is a per-tile-type base mask shifted left by the op's
/// allocated tile ID. The base mask values are on elided lines; presumably
/// each is the mask for tile 0 of that type.
418 struct ZeroOpConversion :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
419 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
422 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
424 auto loc = zero.getLoc();
426 auto tileId = getTileIdOrError(zero);
434 arm_sme::ArmSMETileType tileType =
// Base mask per tile type (case bodies elided). There is no ZAQ case in this
// listing, so a ZAQ tile would reach the llvm_unreachable below.
// NOTE(review): confirm that zeroing ZAQ tiles is rejected before this point.
436 auto baseMaskForSize = [&] {
438 case arm_sme::ArmSMETileType::ZAB:
442 case arm_sme::ArmSMETileType::ZAH:
447 case arm_sme::ArmSMETileType::ZAS:
452 case arm_sme::ArmSMETileType::ZAD:
457 llvm_unreachable(
"bad element size");
// baseMaskForSize is used as an integer here, so the lambda is presumably
// invoked immediately on an elided line.
482 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
483 rewriter.
create<arm_sme::aarch64_sme_zero>(
/// Lowers arm_sme.load_tile_slice to the matching ld1* intrinsic
/// (createLoadTileSliceIntrinsic).
/// It computes the element pointer from the converted memref base and
/// indices, casts the slice index to i32 (arith.index_castui), and forwards
/// the mask, layout and allocated tile ID. The intrinsic writes the tile named
/// by the tile ID, so uses of the op's result are redirected to its input tile
/// value.
497 struct LoadTileSliceConversion
498 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
499 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
502 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
503 arm_sme::LoadTileSliceOp::Adaptor adaptor,
505 auto loc = loadTileSliceOp.getLoc();
506 auto tileId = getTileIdOrError(loadTileSliceOp);
// Address of the first element to load, from the converted base and indices.
510 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
512 adaptor.getIndices(), rewriter);
514 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// The intrinsics take the slice index as i32.
517 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
521 auto maskOp = loadTileSliceOp.getMask();
523 auto tileVectorType = loadTileSliceOp.getVectorType();
525 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
528 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
529 tileId, tileSliceI32);
533 rewriter.
replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
/// Lowers arm_sme.store_tile_slice to the matching st1* intrinsic
/// (createStoreTileSliceIntrinsic).
/// The element pointer comes from the converted memref base and indices. The
/// slice index is cast to i32 (arith.index_castui). The mask, layout and
/// allocated tile ID are forwarded. The intrinsic is created as the argument
/// of an elided outer call (note the closing `))`); presumably that call
/// replaces the op.
540 struct StoreTileSliceConversion
541 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
542 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
545 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
546 arm_sme::StoreTileSliceOp::Adaptor adaptor,
548 auto loc = storeTileSliceOp.getLoc();
549 auto tileVectorType = storeTileSliceOp.getVectorType();
551 auto tileId = getTileIdOrError(storeTileSliceOp);
// Address of the first element to store, from the converted base and indices.
556 Value ptr = this->getStridedElementPtr(
557 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558 adaptor.getIndices(), rewriter);
560 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// The intrinsics take the slice index as i32.
563 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
566 auto maskOp = storeTileSliceOp.getMask();
568 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
572 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
574 tileId, tileSliceI32));
/// Lowers arm_sme.insert_tile_slice to aarch64_sme_write_horiz or
/// aarch64_sme_write_vert, chosen by the op's layout.
/// The op's vector is written into the slice at index `tileSlice` (cast to i32
/// with arith.index_castui) of the allocated tile, under an all-active
/// predicate. Uses of the op's result are then redirected to its input tile
/// value.
581 struct InsertTileSliceConversion
582 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
583 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
586 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
587 arm_sme::InsertTileSliceOp::Adaptor adaptor,
589 auto loc = insertTileSliceOp.getLoc();
590 auto tileType = insertTileSliceOp.getTileType();
592 auto tileId = getTileIdOrError(insertTileSliceOp);
596 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
599 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
// All-active predicate: a splat of the constant `one` into `predTy`. Both
// definitions are partly elided.
603 auto one = rewriter.
create<arith::ConstantOp>(
608 auto allActiveMask = rewriter.
create<vector::SplatOp>(loc, predTy, one);
611 switch (insertTileSliceOp.getLayout()) {
612 case arm_sme::TileSliceLayout::Horizontal:
613 rewriter.
create<arm_sme::aarch64_sme_write_horiz>(
614 loc, tileId, tileSliceI32, allActiveMask,
615 insertTileSliceOp.getVector());
617 case arm_sme::TileSliceLayout::Vertical:
618 rewriter.
create<arm_sme::aarch64_sme_write_vert>(
619 loc, tileId, tileSliceI32, allActiveMask,
620 insertTileSliceOp.getVector());
626 rewriter.
replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
/// Lowers arm_sme.extract_tile_slice to aarch64_sme_read_horiz or
/// aarch64_sme_read_vert, chosen by the op's layout.
/// The read uses an all-true predicate of the slice shape and a zero vector as
/// passthrough. The `extractTileSlice, sliceType, ...` argument lists suggest
/// the read op is created in place of the op (replaceOpWithNewOp); the call
/// heads are elided.
/// NOTE(review): the slice index is cast with arith.index_cast (signed) here,
/// while the load/store/insert patterns use arith.index_castui. The two are
/// equivalent for non-negative indices; consider making them consistent.
633 struct ExtractTileSliceConversion
634 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
635 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
638 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
640 auto loc = extractTileSlice.getLoc();
641 auto sliceType = extractTileSlice.getSliceType();
642 auto sliceIndex = extractTileSlice.getTileSliceIndex();
644 auto tileId = getTileIdOrError(extractTileSlice);
// i1 predicate with the same shape as the slice.
649 auto predicateType = sliceType.cloneWith({}, rewriter.
getI1Type());
650 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
654 auto zeroVector = rewriter.
create<arith::ConstantOp>(
658 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
662 switch (extractTileSlice.getLayout()) {
663 case arm_sme::TileSliceLayout::Horizontal:
665 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
668 case arm_sme::TileSliceLayout::Vertical:
670 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
/// Lowers arm_sme.outerproduct to aarch64_sme_mopa on the allocated tile.
/// Only combining kind `add` and result types accepted by isSupportedType are
/// handled. Other kinds and types are reported with emitError.
/// NOTE(review): emitError from a conversion pattern produces a hard
/// diagnostic even though the pattern fails; confirm this is preferred over
/// notifyMatchFailure.
693 struct OuterProductOpConversion
694 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
695 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
698 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
699 arm_sme::OuterProductOp::Adaptor adaptor,
701 auto tileId = getTileIdOrError(outerProductOp);
// A type is supported when all of the following hold:
// - it is a rank-2 vector with all dimensions scalable;
// - its element type is f16, bf16, f32 or f64;
// - its shape matches a value derived from the element bit width (that
//   computation is partly elided).
705 auto isSupportedType = [](VectorType vectorType) {
719 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
722 auto elementType = vectorType.getElementType();
724 if (!elementType.isF16() && !elementType.isBF16() &&
725 !elementType.isF32() && !elementType.isF64())
729 vectorType.getElementTypeBitWidth();
730 return vectorType.getShape() ==
735 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
736 return outerProductOp.emitError(
"unsupported kind");
738 auto resultVectorType = outerProductOp.getResultType();
739 if (!isSupportedType(resultVectorType))
740 return outerProductOp.emitError(
"unsupported type");
742 auto loc = outerProductOp.getLoc();
// Without an accumulator (guard elided), a zeroed tile with the same tile ID
// is materialized to accumulate into.
744 Value acc = outerProductOp.getAcc();
747 auto zero = rewriter.
create<arm_sme::ZeroOp>(loc, resultVectorType);
748 zero.setTileId(tileId);
// If either mask is missing, BOTH are replaced with an all-active mask.
// NOTE(review): a single provided mask would be dropped here. Presumably the
// op verifier requires both masks or neither; confirm.
752 Value lhsMask = outerProductOp.getLhsMask();
753 Value rhsMask = outerProductOp.getRhsMask();
755 if (!lhsMask || !rhsMask) {
757 outerProductOp.getLhsType().cloneWith({}, rewriter.
getI1Type());
758 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
760 lhsMask = allActiveMask;
761 rhsMask = allActiveMask;
// NOTE(review): this passes the original operands (outerProductOp.getLhs/Rhs).
// OuterProductWideningOpConversion uses the adaptor's converted operands
// instead; consider using adaptor.getLhs()/getRhs() for consistency.
765 rewriter.
create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
766 outerProductOp.getLhs(),
767 outerProductOp.getRhs());
/// Generic lowering of a widening outer-product op (OuterProductWideningOp) to
/// its intrinsic (OuterProductWideningIntrOp) on the allocated tile.
/// - If there is no accumulator (guard elided), a zeroed tile with the same
///   tile ID is materialized.
/// - If either mask is missing, both are replaced by an all-active mask.
/// - The lhs/rhs come from the adaptor (converted operands).
/// NOTE(review): in this listing the template parameter names are split across
/// lines ("OuterProductW" / "ideningOp"), which would not compile. They should
/// read OuterProductWideningOp and OuterProductWideningIntrOp, as used below.
778 template <
class OuterProductW
ideningOp,
class OuterProductW
ideningIntrOp>
779 struct OuterProductWideningOpConversion
780 :
public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
781 using ConvertArmSMEOpToLLVMPattern<
782 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
785 matchAndRewrite(OuterProductWideningOp op,
786 typename OuterProductWideningOp::Adaptor adaptor,
788 auto tileId = getTileIdOrError(op);
792 auto loc = op.getLoc();
793 Value acc = op.getAcc();
796 auto zero = rewriter.
create<arm_sme::ZeroOp>(loc, op.getResultType());
797 zero.setTileId(tileId);
// Missing mask(s): both become an all-active i1 mask shaped like lhs.
801 Value lhsMask = op.getLhsMask();
802 Value rhsMask = op.getRhsMask();
803 if (!lhsMask || !rhsMask) {
804 auto predTy = op.getLhsType().cloneWith({}, rewriter.
getI1Type());
805 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
807 lhsMask = allActiveMask;
808 rhsMask = allActiveMask;
811 rewriter.
create<OuterProductWideningIntrOp>(
812 loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
/// Lowers arm_sme.streaming_vl to one of the aarch64_sme_cnts{b,h,w,d}
/// intrinsics.
/// The intrinsic is selected by the op's type size (Byte/Half/Word/Double),
/// and each returns an i64. The i64 result (`intrOp`) is then converted to
/// `index` to replace the op; the replacing call head is elided.
/// This pattern opts out of the spills/fills pattern
/// (RequiresSpillsAndFills::No).
833 struct StreamingVLOpConversion
834 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835 RequiresSpillsAndFills::No> {
836 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
839 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840 arm_sme::StreamingVLOp::Adaptor adaptor,
842 auto loc = streamingVlOp.getLoc();
// Every case returns, so llvm_unreachable is hit only for an unhandled size.
845 switch (streamingVlOp.getTypeSize()) {
846 case arm_sme::TypeSize::Byte:
847 return rewriter.
create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848 case arm_sme::TypeSize::Half:
849 return rewriter.
create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850 case arm_sme::TypeSize::Word:
851 return rewriter.
create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852 case arm_sme::TypeSize::Double:
853 return rewriter.
create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
855 llvm_unreachable(
"unknown type size in StreamingVLOpConversion");
858 streamingVlOp, rewriter.
getIndexType(), intrOp->getResult(0));
/// Merges each run of consecutive aarch64_sme_zero ops in `block` into a
/// single aarch64_sme_zero whose tile mask is the bitwise OR of the run's
/// masks.
865 static void mergeConsecutiveTileZerosInBlock(
Block *block) {
866 uint32_t mergedZeroMask = 0;
// Flushes the current run. The scope_exit cleanup runs on every exit from
// the lambda, including the early return, and clears the pending list (the
// mask reset is presumably on an elided line). Runs of zero or one op are left
// untouched.
868 auto replaceMergedZeroOps = [&] {
869 auto cleanup = llvm::make_scope_exit([&] {
871 zeroOpsToMerge.clear();
873 if (zeroOpsToMerge.size() <= 1)
// One combined zero at the location of the run's first op; the original ops
// are then handled in the loop below (body elided; presumably erased).
876 rewriter.
create<arm_sme::aarch64_sme_zero>(
877 zeroOpsToMerge.front().getLoc(),
879 for (
auto zeroOp : zeroOpsToMerge)
// Accumulate zero ops into the current run. The branch that calls
// replaceMergedZeroOps is presumably the elided `else` taken for a non-zero
// op, which ends the run.
883 if (
auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
884 mergedZeroMask |= zeroOp.getTileMask();
885 zeroOpsToMerge.push_back(zeroOp);
887 replaceMergedZeroOps();
// Flush a run that extends to the end of the block.
890 replaceMergedZeroOps();
/// Pass that lowers ArmSME ops in a function to LLVM intrinsics.
/// It signals failure on an elided condition, presumably tile allocation.
/// After conversion it runs these steps:
/// 1. Merge consecutive aarch64_sme_zero ops in every block.
/// 2. Emit an error for any remaining op with an SME tile type. The allowed
///    exceptions visible here are CopyTileOp, GetTileOp and cf::BranchOp.
897 struct ConvertArmSMEToLLVMPass
898 :
public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
// NOTE(review): single-argument constructor; consider marking it `explicit`
// to prevent implicit conversion from bool.
899 ConvertArmSMEToLLVMPass(
bool dumpTileLiveRanges) {
900 this->dumpTileLiveRanges = dumpTileLiveRanges;
902 void runOnOperation()
override {
903 auto function = getOperation();
906 return signalPassFailure();
// Post-conversion cleanup: combine adjacent zero intrinsics.
917 function->walk(mergeConsecutiveTileZerosInBlock);
// Verification: tile-typed values should only remain on these ops.
923 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
926 auto isSMETileType = [](
Type type) {
927 return arm_sme::isValidSMETileVectorType(type);
931 op->emitOpError(
"unexpected operation with SME tile type after "
932 "conversion to LLVM");
// (Function head and the start of this list are elided; presumably
// target.addLegalOp<...>.) These ArmSME LLVM-intrinsic ops are the legal
// results of the conversion, including every intrinsic the patterns create.
944 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
945 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
946 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
947 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
948 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
949 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
950 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
951 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
952 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
953 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
954 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
955 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
956 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
957 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
958 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
959 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
960 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
961 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
962 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
963 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
964 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
965 arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
966 arm_sme::aarch64_sme_cntsd>();
// Legal dialects (list head elided). This includes vector, scf and memref,
// which the spills/fills lowering emits (vector.store, scf.for, memref.alloca).
969 vector::VectorDialect, scf::SCFDialect,
970 memref::MemRefDialect>();
// Tile bookkeeping ops and the unrealized casts created during conversion
// stay legal. The pass checks afterwards that tile-typed values only remain
// on CopyTileOp, GetTileOp and cf::BranchOp.
974 target.
addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
975 UnrealizedConversionCastOp>();
// (Function head elided.) Registers a VectorType conversion hook. Its body is
// elided in this listing.
980 converter.
addConversion([&](VectorType type) -> std::optional<Type> {
// Registers every ArmSME-to-intrinsic pattern. For each pattern whose op
// implements ArmSMETileOpInterface, addArmSMEConversionPatterns also adds
// ConvertArmSMESpillsAndFillsToLLVM, unless the pattern opts out with
// RequiresSpillsAndFills::No.
// Widening mapping:
// - FMopa/FMops2Way -> mopa_wide/mops_wide;
// - integer 2-way ops -> *_za32 intrinsics;
// - integer 4-way ops -> *_wide intrinsics.
988 addArmSMEConversionPatterns<
989 LoadTileSliceConversion, ExtractTileSliceConversion,
990 InsertTileSliceConversion, StoreTileSliceConversion,
991 StreamingVLOpConversion, OuterProductOpConversion,
992 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
993 arm_sme::aarch64_sme_mopa_wide>,
994 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
995 arm_sme::aarch64_sme_mops_wide>,
996 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
997 arm_sme::aarch64_sme_smopa_za32>,
998 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
999 arm_sme::aarch64_sme_smops_za32>,
1000 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
1001 arm_sme::aarch64_sme_umopa_za32>,
1002 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1003 arm_sme::aarch64_sme_umops_za32>,
1004 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1005 arm_sme::aarch64_sme_smopa_wide>,
1006 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1007 arm_sme::aarch64_sme_smops_wide>,
1008 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1009 arm_sme::aarch64_sme_umopa_wide>,
1010 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1011 arm_sme::aarch64_sme_umops_wide>,
1012 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1013 arm_sme::aarch64_sme_sumopa_wide>,
1014 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1015 arm_sme::aarch64_sme_sumops_wide>,
1016 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1017 arm_sme::aarch64_sme_usmopa_wide>,
1018 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1019 arm_sme::aarch64_sme_usmops_wide>,
1020 ZeroOpConversion>(
patterns, converter);
/// Factory for ConvertArmSMEToLLVMPass. It forwards `dumpTileLiveRanges` to
/// the pass. The function name and parameter line are elided in this listing.
1023 std::unique_ptr<Pass>
1025 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for operation conversions targeting the LLVM IR dialect.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
constexpr unsigned MinStreamingVectorLengthInBits
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.