27 #include "llvm/ADT/ScopeExit.h"
30 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31 #include "mlir/Conversion/Passes.h.inc"
/// Name of the discardable attribute used to tag a tile alloca with the ID of
/// the in-memory tile it belongs to (read and set in
/// getOrCreateAllocaForTile).
38 static constexpr StringLiteral kInMemoryTileIdAttr(
"arm_sme.in_memory_tile_id");
/// Creates the SME load-tile-slice intrinsic that matches the given tile type
/// and slice layout, forwarding the mask, pointer, tile ID and i32 slice index
/// to it.
/// NOTE(review): parts of this definition (including the `switch` headers)
/// are not visible in this listing.
41 static Operation *createLoadTileSliceIntrinsic(
43 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
44 IntegerAttr tileId,
Value tileSliceI32) {
// Horizontal layout: one ld1{b,h,w,d,q}_horiz intrinsic per tile type.
45 if (layout == arm_sme::TileSliceLayout::Horizontal) {
47 case arm_sme::ArmSMETileType::ZAB:
48 return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
49 tileId, tileSliceI32);
50 case arm_sme::ArmSMETileType::ZAH:
51 return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
52 tileId, tileSliceI32);
53 case arm_sme::ArmSMETileType::ZAS:
54 return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
55 tileId, tileSliceI32);
56 case arm_sme::ArmSMETileType::ZAD:
57 return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
58 tileId, tileSliceI32);
59 case arm_sme::ArmSMETileType::ZAQ:
60 return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
61 tileId, tileSliceI32);
// The remaining cases create the ld1{b,h,w,d,q}_vert variants.
65 case arm_sme::ArmSMETileType::ZAB:
66 return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
67 tileId, tileSliceI32);
68 case arm_sme::ArmSMETileType::ZAH:
69 return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
70 tileId, tileSliceI32);
71 case arm_sme::ArmSMETileType::ZAS:
72 return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
73 tileId, tileSliceI32);
74 case arm_sme::ArmSMETileType::ZAD:
75 return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
76 tileId, tileSliceI32);
77 case arm_sme::ArmSMETileType::ZAQ:
78 return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
79 tileId, tileSliceI32);
// Reached only if none of the cases above returned.
83 llvm_unreachable(
"unknown type in createLoadTileSliceIntrinsic");
/// Creates the SME store-tile-slice intrinsic that matches the given tile type
/// and slice layout, forwarding the mask, pointer, tile ID and i32 slice index
/// to it.
/// NOTE(review): parts of this definition (including the `switch` headers)
/// are not visible in this listing.
87 static Operation *createStoreTileSliceIntrinsic(
89 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
90 IntegerAttr tileId,
Value tileSliceI32) {
// Horizontal layout: one st1{b,h,w,d,q}_horiz intrinsic per tile type.
91 if (layout == arm_sme::TileSliceLayout::Horizontal) {
93 case arm_sme::ArmSMETileType::ZAB:
94 return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
95 tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAH:
97 return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
98 tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAS:
100 return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
101 tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAD:
103 return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
104 tileId, tileSliceI32);
105 case arm_sme::ArmSMETileType::ZAQ:
106 return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
107 tileId, tileSliceI32);
// The remaining cases create the st1{b,h,w,d,q}_vert variants.
111 case arm_sme::ArmSMETileType::ZAB:
112 return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
113 tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAH:
115 return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
116 tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAS:
118 return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
119 tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAD:
121 return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
122 tileId, tileSliceI32);
123 case arm_sme::ArmSMETileType::ZAQ:
124 return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
125 tileId, tileSliceI32);
// Reached only if none of the cases above returned.
128 llvm_unreachable(
"unknown type in createStoreTileSliceIntrinsic");
/// Looks up the tile ID attribute of `op`.
/// NOTE(review): the check that emits the diagnostic below is not visible in
/// this listing; presumably it fires when no tile ID has been assigned yet.
131 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
132 auto tileId = op.getTileId();
135 "expected tile ID to be allocated before conversion to LLVM");
/// Creates a `memref::AllocaOp` for the tile of `tileOp`, using a 2-D
/// dynamically-shaped memref of the tile's element type.
/// NOTE(review): the line carrying the function name is missing from this
/// listing; the call in getOrCreateAllocaForTile suggests it is
/// createAllocaForTile.
141 static memref::AllocaOp
143 FunctionOpInterface func,
144 arm_sme::ArmSMETileOpInterface tileOp) {
// Restores the rewriter's insertion point when this scope exits.
145 RewriterBase::InsertionGuard g(rewriter);
149 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
150 auto tileElementType = tileOp.getTileType().getElementType();
152 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
// Runtime vector length: vscale multiplied by the minimum element count.
156 auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
157 auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
/// Finds the alloca tagged with `tileId` via kInMemoryTileIdAttr, or creates
/// and tags a new one with createAllocaForTile.
/// NOTE(review): the return statements are not visible in this listing.
163 static memref::AllocaOp getOrCreateAllocaForTile(
165 arm_sme::ArmSMETileOpInterface tileOp,
unsigned tileId) {
// Scan the ops of the function's first block for a matching tagged alloca.
168 for (
auto &op : func.getBlocks().front()) {
169 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
172 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
173 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
176 if (inMemoryTileId.getInt() == tileId)
// Fallback path: create a fresh alloca and tag it with the tile ID attribute.
180 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
181 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
/// Constructs the spill/fill pattern for the op named `rootOpName`; the type
/// converter and benefit are forwarded to the base class initializer.
238 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
242 typeConverter, benefit) {}
// NOTE(review): the signature of this method is not visible in this listing.
247 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
// Check whether this op uses an in-memory tile.
249 if (!tileOp.isInMemoryTile())
// Diagnostic text describing the performance impact of in-memory tiles.
253 "failed to allocate SME virtual tile to operation, tile value will go "
254 "through memory, expect degraded performance");
// Find (or create) the alloca associated with this op's tile ID.
258 auto loc = tileOp.getLoc();
259 auto func = tileOp->getParentOfType<FunctionOpInterface>();
260 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
261 tileOp.getTileId().getInt());
// Retarget the op to `zeroTileId`, notifying the rewriter of the in-place
// modification.
266 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
268 VectorType tileVectorType = tileOp.getTileType();
// Helper that emits a full tile swap against `tileAlloca`.
270 auto swapInMemoryTileWithSMETileZero = [&] {
271 emitFullTileSwap(rewriter, loc, tileAlloca,
// The swap is emitted twice. NOTE(review): presumably once before and once
// after the op; the insertion-point changes are not visible here.
282 swapInMemoryTileWithSMETileZero();
285 swapInMemoryTileWithSMETileZero();
// NOTE(review): the signature of this helper is not visible in this listing;
// the call in emitSliceSwap suggests it is getInMemoryTileSlicePtr.
// Cast `tileMemory` to its converted LLVM type via an unrealized conversion
// cast.
294 auto llvmType = getTypeConverter()->convertType(tileMemory.
getType());
296 UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
// Index-cast the slice index to i64.
298 auto sliceIndexI64 = arith::IndexCastOp::create(
299 rewriter, loc, rewriter.
getI64Type(), sliceIndex);
// Address the memref descriptor with the indices {sliceIndexI64, zero}.
302 llvm::cast<MemRefType>(tileMemory.
getType()), descriptor.getResult(0),
303 {sliceIndexI64, zero});
// Swaps a single tile slice: reads the current slice from the SME tile, loads
// the in-memory slice into the tile, then stores the previously read slice to
// the tile alloca.
// NOTE(review): the first line of the signature is not visible in this
// listing; the call in emitFullTileSwap suggests this is emitSliceSwap.
309 arm_sme::ArmSMETileType tileType, VectorType sliceType,
310 IntegerAttr tileId,
Value sliceIndex)
const {
// Slice index as i32, as passed to the load intrinsic below.
312 auto sliceIndexI32 = arith::IndexCastOp::create(
313 rewriter, loc, rewriter.
getI32Type(), sliceIndex);
// Predicate type: the slice type cloned with i1 elements.
315 auto predicateType = sliceType.clone(rewriter.
getI1Type());
316 auto allTruePredicate = arith::ConstantOp::create(
// Poison vector passed to the tile read below.
319 auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
322 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the current slice out of the SME tile.
324 auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
325 rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
// Load the in-memory slice into the SME tile (horizontal layout).
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
// Store the previously read slice to the tile alloca.
333 vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
// Emits an scf.for loop that calls emitSliceSwap with the loop induction
// variable as the slice index.
// NOTE(review): the first line of the signature is not visible in this
// listing; the call site in the spill/fill rewrite suggests this is
// emitFullTileSwap.
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId)
const {
// Restores the rewriter's insertion point when this scope exits.
342 RewriterBase::InsertionGuard guard(rewriter);
// Scalable count: minNumElts multiplied by vscale.
348 arith::MulIOp::create(rewriter, loc, minNumElts,
349 vector::VectorScaleOp::create(rewriter, loc));
352 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
// Swap the slice selected by the current loop iteration.
355 auto sliceIndex = forOp.getInductionVar();
356 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
/// Template flag for ConvertArmSMEOpToLLVMPattern: `Yes` makes
/// requiresSpillsAndFillsConversion() return true, which
/// addArmSMEConversionPattern checks before also registering
/// ConvertArmSMESpillsAndFillsToLLVM for the op.
361 enum class RequiresSpillsAndFills { Yes, No };
/// Pattern base parameterised on the source op and on whether the op requires
/// the spills-and-fills conversion (defaults to `Yes`).
/// NOTE(review): the struct header line is not visible in this listing.
367 template <
typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
368 RequiresSpillsAndFills::Yes>
// Exposes the source op type to addArmSMEConversionPattern.
370 using ArmSMEOp = SourceOp;
// True when the pattern was instantiated with RequiresSpillsAndFills::Yes.
373 static constexpr
bool requiresSpillsAndFillsConversion() {
374 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
/// Registers one conversion pattern.
/// NOTE(review): the signature and the registration of `Pattern` itself are
/// not visible in this listing.
378 template <
typename Pattern>
// Compile-time check: the pattern asks for spills/fills and its op type
// derives from the ArmSMETileOpInterface trait.
384 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
385 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
386 typename Pattern::ArmSMEOp>,
387 typename Pattern::ArmSMEOp>) {
// Also register the spill/fill pattern, keyed on the op's operation name.
390 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
391 Pattern::ArmSMEOp::getOperationName(), typeConverter,
/// Variadic helper: calls addArmSMEConversionPattern once for each pattern in
/// `Patterns` using a comma fold expression.
398 template <
typename... Patterns>
402 (addArmSMEConversionPattern<Patterns>(
patterns, typeConverter), ...);
/// Lowers `arm_sme::ZeroOp` to the `aarch64_sme_zero` intrinsic, using a tile
/// mask derived from the op's tile type and tile ID.
420 struct ZeroOpConversion :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
421 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
424 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
426 auto loc = zero.getLoc();
428 auto tileId = getTileIdOrError(zero);
436 arm_sme::ArmSMETileType tileType =
// Base mask chosen per tile type (the mask values are not visible in this
// listing).
438 auto baseMaskForSize = [&] {
440 case arm_sme::ArmSMETileType::ZAB:
444 case arm_sme::ArmSMETileType::ZAH:
449 case arm_sme::ArmSMETileType::ZAS:
454 case arm_sme::ArmSMETileType::ZAD:
459 llvm_unreachable(
"bad element size");
// Shift the base mask left by the tile ID to get the mask for this tile.
484 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
485 arm_sme::aarch64_sme_zero::create(rewriter, loc,
/// Lowers `arm_sme::LoadTileSliceOp` to the load intrinsic selected by
/// createLoadTileSliceIntrinsic.
499 struct LoadTileSliceConversion
500 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
504 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505 arm_sme::LoadTileSliceOp::Adaptor adaptor,
507 auto loc = loadTileSliceOp.getLoc();
508 auto tileId = getTileIdOrError(loadTileSliceOp);
// Arguments of a call whose first line is not visible in this listing;
// presumably the computation of the source pointer `ptr`.
513 rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
514 adaptor.getIndices());
516 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Unsigned index cast of the tile slice index to i32.
519 auto tileSliceI32 = arith::IndexCastUIOp::create(
520 rewriter, loc, rewriter.
getI32Type(), tileSlice);
523 auto maskOp = loadTileSliceOp.getMask();
525 auto tileVectorType = loadTileSliceOp.getVectorType();
527 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
// Emit the load intrinsic for this tile type and layout.
530 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531 tileId, tileSliceI32);
// Replace the op's result with its tile operand.
535 rewriter.
replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
/// Lowers `arm_sme::StoreTileSliceOp` to the store intrinsic selected by
/// createStoreTileSliceIntrinsic.
542 struct StoreTileSliceConversion
543 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
547 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548 arm_sme::StoreTileSliceOp::Adaptor adaptor,
550 auto loc = storeTileSliceOp.getLoc();
551 auto tileVectorType = storeTileSliceOp.getVectorType();
553 auto tileId = getTileIdOrError(storeTileSliceOp);
// Arguments of a call whose first line is not visible in this listing;
// presumably the computation of the destination pointer.
559 rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560 adaptor.getIndices());
562 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Unsigned index cast of the tile slice index to i32.
565 auto tileSliceI32 = arith::IndexCastUIOp::create(
566 rewriter, loc, rewriter.
getI32Type(), tileSlice);
568 auto maskOp = storeTileSliceOp.getMask();
570 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
// Emit the store intrinsic for this tile type and layout.
574 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
576 tileId, tileSliceI32));
/// Lowers `arm_sme::InsertTileSliceOp` to the `aarch64_sme_write_horiz` or
/// `aarch64_sme_write_vert` intrinsic, depending on the op's layout.
583 struct InsertTileSliceConversion
584 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
585 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
588 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
589 arm_sme::InsertTileSliceOp::Adaptor adaptor,
591 auto loc = insertTileSliceOp.getLoc();
592 auto tileType = insertTileSliceOp.getTileType();
594 auto tileId = getTileIdOrError(insertTileSliceOp);
598 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
// Unsigned index cast of the tile slice index to i32.
601 auto tileSliceI32 = arith::IndexCastUIOp::create(
602 rewriter, loc, rewriter.
getI32Type(), tileSlice);
// Broadcast the constant `one` to the predicate type to form the mask.
605 auto one = arith::ConstantOp::create(
611 vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Write the vector into the tile slice using the layout-specific intrinsic.
614 switch (insertTileSliceOp.getLayout()) {
615 case arm_sme::TileSliceLayout::Horizontal:
616 arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
617 tileSliceI32, allActiveMask,
618 insertTileSliceOp.getVector());
620 case arm_sme::TileSliceLayout::Vertical:
621 arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
622 tileSliceI32, allActiveMask,
623 insertTileSliceOp.getVector());
// Replace the op's result with its tile operand.
629 rewriter.
replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
/// Lowers `arm_sme::ExtractTileSliceOp`, handling the horizontal and vertical
/// layouts separately.
/// NOTE(review): the ops created in each case are not visible in this
/// listing.
636 struct ExtractTileSliceConversion
637 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
638 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
641 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
643 auto loc = extractTileSlice.getLoc();
644 auto sliceType = extractTileSlice.getSliceType();
645 auto sliceIndex = extractTileSlice.getTileSliceIndex();
647 auto tileId = getTileIdOrError(extractTileSlice);
// Predicate type: the slice type's shape with i1 elements.
652 auto predicateType = sliceType.cloneWith({}, rewriter.
getI1Type());
653 auto allTruePredicate = arith::ConstantOp::create(
// Zero-valued vector of the slice type, passed to the rewrite below.
657 auto zeroVector = arith::ConstantOp::create(
658 rewriter, loc, sliceType, rewriter.
getZeroAttr(sliceType));
// Slice index as i32.
661 auto sliceIndexI32 = arith::IndexCastOp::create(
662 rewriter, loc, rewriter.
getI32Type(), sliceIndex);
// Both layouts pass the same operands.
665 switch (extractTileSlice.getLayout()) {
666 case arm_sme::TileSliceLayout::Horizontal:
668 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
671 case arm_sme::TileSliceLayout::Vertical:
673 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
/// Lowers `arm_sme::OuterProductOp` to the `aarch64_sme_mopa` intrinsic.
/// Emits an error for combining kinds other than Add and for unsupported
/// result types.
696 struct OuterProductOpConversion
697 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
698 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
701 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
702 arm_sme::OuterProductOp::Adaptor adaptor,
704 auto tileId = getTileIdOrError(outerProductOp);
// Type check: tests for rank 2 with all dims scalable, an f16/bf16/f32/f64
// element type, and a shape comparison (right-hand side not visible here).
708 auto isSupportedType = [](VectorType vectorType) {
722 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
725 auto elementType = vectorType.getElementType();
727 if (!elementType.isF16() && !elementType.isBF16() &&
728 !elementType.isF32() && !elementType.isF64())
732 vectorType.getElementTypeBitWidth();
733 return vectorType.getShape() ==
// Only the Add combining kind is handled.
738 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
739 return outerProductOp.emitError(
"unsupported kind");
741 auto resultVectorType = outerProductOp.getResultType();
742 if (!isSupportedType(resultVectorType))
743 return outerProductOp.emitError(
"unsupported type");
745 auto loc = outerProductOp.getLoc();
747 Value acc = outerProductOp.getAcc();
// Create a ZeroOp assigned to the same tile ID. NOTE(review): the guarding
// condition is not visible here (presumably when no accumulator is given).
750 auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
751 zero.setTileId(tileId);
755 Value lhsMask = outerProductOp.getLhsMask();
756 Value rhsMask = outerProductOp.getRhsMask();
// If either mask is missing, both are replaced by the same all-active mask.
758 if (!lhsMask || !rhsMask) {
760 outerProductOp.getLhsType().cloneWith({}, rewriter.
getI1Type());
761 Value allActiveMask = arith::ConstantOp::create(
763 lhsMask = allActiveMask;
764 rhsMask = allActiveMask;
// Emit the outer-product intrinsic with the tile ID, masks and operands.
768 arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
769 outerProductOp.getLhs(),
770 outerProductOp.getRhs());
/// Generic lowering of a widening outer-product op to its intrinsic: the
/// first template argument is the source op, the second the intrinsic op
/// that is created in its place.
781 template <
class OuterProductW
ideningOp,
class OuterProductW
ideningIntrOp>
782 struct OuterProductWideningOpConversion
783 :
public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
784 using ConvertArmSMEOpToLLVMPattern<
785 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
788 matchAndRewrite(OuterProductWideningOp op,
789 typename OuterProductWideningOp::Adaptor adaptor,
791 auto tileId = getTileIdOrError(op);
795 auto loc = op.getLoc();
796 Value acc = op.getAcc();
// Create a ZeroOp assigned to the same tile ID. NOTE(review): the guarding
// condition is not visible here (presumably when no accumulator is given).
799 auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
800 zero.setTileId(tileId);
804 Value lhsMask = op.getLhsMask();
805 Value rhsMask = op.getRhsMask();
// If either mask is missing, both are replaced by the same all-active mask.
806 if (!lhsMask || !rhsMask) {
807 auto predTy = op.getLhsType().cloneWith({}, rewriter.
getI1Type());
808 Value allActiveMask = arith::ConstantOp::create(
810 lhsMask = allActiveMask;
811 rhsMask = allActiveMask;
// Emit the intrinsic with the tile ID, masks and converted operands.
814 OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
815 adaptor.getLhs(), adaptor.getRhs());
/// Lowers `arm_sme::StreamingVLOp` to one of the `aarch64_sme_cnts{b,h,w,d}`
/// intrinsics. Instantiated with RequiresSpillsAndFills::No.
836 struct StreamingVLOpConversion
837 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
838 RequiresSpillsAndFills::No> {
839 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
842 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
843 arm_sme::StreamingVLOp::Adaptor adaptor,
845 auto loc = streamingVlOp.getLoc();
// Pick the counting intrinsic that matches the op's type size; each one
// produces an i64Type result.
848 switch (streamingVlOp.getTypeSize()) {
849 case arm_sme::TypeSize::Byte:
850 return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
851 case arm_sme::TypeSize::Half:
852 return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
853 case arm_sme::TypeSize::Word:
854 return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
855 case arm_sme::TypeSize::Double:
856 return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
858 llvm_unreachable(
"unknown type size in StreamingVLOpConversion");
// Replace the op using the intrinsic's result, with an index result type.
861 streamingVlOp, rewriter.
getIndexType(), intrOp->getResult(0));
/// Collects `aarch64_sme_zero` ops found in `block`, OR-ing their tile masks
/// together, and replaces each collected group of more than one op with a
/// single merged zero op.
868 static void mergeConsecutiveTileZerosInBlock(
Block *block) {
// Union of the tile masks of the zero ops collected so far.
869 uint32_t mergedZeroMask = 0;
871 auto replaceMergedZeroOps = [&] {
// Clears the collected ops whenever this lambda exits.
872 auto cleanup = llvm::make_scope_exit([&] {
874 zeroOpsToMerge.clear();
// Nothing to merge for zero or one collected op.
876 if (zeroOpsToMerge.size() <= 1)
// Create the merged op at the location of the first collected op.
879 arm_sme::aarch64_sme_zero::create(
880 rewriter, zeroOpsToMerge.front().getLoc(),
882 for (
auto zeroOp : zeroOpsToMerge)
// Collect zero ops; any other op ends the current group.
886 if (
auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
887 mergedZeroMask |= zeroOp.getTileMask();
888 zeroOpsToMerge.push_back(zeroOp);
890 replaceMergedZeroOps();
// Flush whatever is still collected once the walk over the block is done.
893 replaceMergedZeroOps();
/// Pass that converts ArmSME ops to LLVM intrinsics.
/// NOTE(review): several steps of runOnOperation are not visible in this
/// listing.
900 struct ConvertArmSMEToLLVMPass
901 :
public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
// Stores the `dumpTileLiveRanges` flag as a pass option.
902 ConvertArmSMEToLLVMPass(
bool dumpTileLiveRanges) {
903 this->dumpTileLiveRanges = dumpTileLiveRanges;
905 void runOnOperation()
override {
906 auto function = getOperation();
// Failure exit; the failing condition is not visible in this listing.
909 return signalPassFailure();
// Merge zero intrinsics block by block.
920 function->walk(mergeConsecutiveTileZerosInBlock);
// Check involving tile copy/get ops and branches; the rest of the condition
// is not visible in this listing.
926 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
929 auto isSMETileType = [](
Type type) {
930 return arm_sme::isValidSMETileVectorType(type);
// Diagnostic for ops still using SME tile types after conversion.
934 op->emitOpError(
"unexpected operation with SME tile type after "
935 "conversion to LLVM");
// Template argument list of a call whose opening line is not visible in this
// listing (presumably target.addLegalOp): the SME intrinsic ops produced by
// this conversion.
947 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
948 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
949 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
950 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
951 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
952 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
953 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
954 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
955 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
956 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
957 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
958 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
959 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
960 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
961 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
962 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
963 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
964 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
965 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
966 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
967 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
968 arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
969 arm_sme::aarch64_sme_cntsd>();
// Dialect list of a second call whose opening line is likewise not visible.
972 vector::VectorDialect, scf::SCFDialect,
973 memref::MemRefDialect>();
// Tile get/copy ops and unrealized conversion casts are marked legal.
977 target.
addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
978 UnrealizedConversionCastOp>();
// Register a VectorType conversion callback on the type converter (the
// callback body is not visible in this listing).
983 converter.
addConversion([&](VectorType type) -> std::optional<Type> {
// Register every ArmSME-to-LLVM conversion pattern. Each widening
// outer-product op is paired with the intrinsic it lowers to.
991 addArmSMEConversionPatterns<
992 LoadTileSliceConversion, ExtractTileSliceConversion,
993 InsertTileSliceConversion, StoreTileSliceConversion,
994 StreamingVLOpConversion, OuterProductOpConversion,
995 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
996 arm_sme::aarch64_sme_mopa_wide>,
997 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
998 arm_sme::aarch64_sme_mops_wide>,
999 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
1000 arm_sme::aarch64_sme_smopa_za32>,
1001 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
1002 arm_sme::aarch64_sme_smops_za32>,
1003 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
1004 arm_sme::aarch64_sme_umopa_za32>,
1005 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1006 arm_sme::aarch64_sme_umops_za32>,
1007 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1008 arm_sme::aarch64_sme_smopa_wide>,
1009 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1010 arm_sme::aarch64_sme_smops_wide>,
1011 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1012 arm_sme::aarch64_sme_umopa_wide>,
1013 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1014 arm_sme::aarch64_sme_umops_wide>,
1015 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1016 arm_sme::aarch64_sme_sumopa_wide>,
1017 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1018 arm_sme::aarch64_sme_sumops_wide>,
1019 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1020 arm_sme::aarch64_sme_usmopa_wide>,
1021 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1022 arm_sme::aarch64_sme_usmops_wide>,
1023 ZeroOpConversion>(
patterns, converter);
/// Factory for the ArmSME-to-LLVM conversion pass; forwards
/// `dumpTileLiveRanges` to the pass constructor.
/// NOTE(review): the line carrying the function name is not visible in this
/// listing.
1026 std::unique_ptr<Pass>
1028 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for operation conversions targeting the LLVM IR dialect.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing uses of its results).
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
constexpr unsigned MinStreamingVectorLengthInBits
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.