30 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31 #include "mlir/Conversion/Passes.h.inc"
/// Name of the discardable attribute that this file sets on, and reads back
/// from, `memref.alloca` ops; its integer value is compared against a tile ID.
38 static constexpr StringLiteral kInMemoryTileIdAttr(
"arm_sme.in_memory_tile_id");
/// Creates the SME tile-slice load intrinsic op (`aarch64_sme_ld1{b,h,w,d,q}`)
/// selected by tile type (ZAB/ZAH/ZAS/ZAD/ZAQ) and slice layout. Every variant
/// is built from the same operands: `loc, maskOp, ptr, tileId, tileSliceI32`.
/// NOTE(review): the embedded line numbers skip (41 -> 43, 45 -> 47, 61 -> 65),
/// so part of the signature and the switch/else scaffolding are not visible in
/// this listing.
41 static Operation *createLoadTileSliceIntrinsic(
43 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
44 IntegerAttr tileId,
Value tileSliceI32) {
45 if (layout == arm_sme::TileSliceLayout::Horizontal) {
// Horizontal layout: one `ld1*_horiz` intrinsic per tile type.
47 case arm_sme::ArmSMETileType::ZAB:
48 return rewriter.
create<arm_sme::aarch64_sme_ld1b_horiz>(
49 loc, maskOp, ptr, tileId, tileSliceI32);
50 case arm_sme::ArmSMETileType::ZAH:
51 return rewriter.
create<arm_sme::aarch64_sme_ld1h_horiz>(
52 loc, maskOp, ptr, tileId, tileSliceI32);
53 case arm_sme::ArmSMETileType::ZAS:
54 return rewriter.
create<arm_sme::aarch64_sme_ld1w_horiz>(
55 loc, maskOp, ptr, tileId, tileSliceI32);
56 case arm_sme::ArmSMETileType::ZAD:
57 return rewriter.
create<arm_sme::aarch64_sme_ld1d_horiz>(
58 loc, maskOp, ptr, tileId, tileSliceI32);
59 case arm_sme::ArmSMETileType::ZAQ:
60 return rewriter.
create<arm_sme::aarch64_sme_ld1q_horiz>(
61 loc, maskOp, ptr, tileId, tileSliceI32);
// `ld1*_vert` variants (presumably the non-horizontal branch; the `else` is
// not visible in this listing).
65 case arm_sme::ArmSMETileType::ZAB:
66 return rewriter.
create<arm_sme::aarch64_sme_ld1b_vert>(
67 loc, maskOp, ptr, tileId, tileSliceI32);
68 case arm_sme::ArmSMETileType::ZAH:
69 return rewriter.
create<arm_sme::aarch64_sme_ld1h_vert>(
70 loc, maskOp, ptr, tileId, tileSliceI32);
71 case arm_sme::ArmSMETileType::ZAS:
72 return rewriter.
create<arm_sme::aarch64_sme_ld1w_vert>(
73 loc, maskOp, ptr, tileId, tileSliceI32);
74 case arm_sme::ArmSMETileType::ZAD:
75 return rewriter.
create<arm_sme::aarch64_sme_ld1d_vert>(
76 loc, maskOp, ptr, tileId, tileSliceI32);
77 case arm_sme::ArmSMETileType::ZAQ:
78 return rewriter.
create<arm_sme::aarch64_sme_ld1q_vert>(
79 loc, maskOp, ptr, tileId, tileSliceI32);
/// Creates the SME tile-slice store intrinsic op (`aarch64_sme_st1{b,h,w,d,q}`)
/// selected by tile type (ZAB/ZAH/ZAS/ZAD/ZAQ) and slice layout. Every variant
/// is built from the same operands: `loc, maskOp, ptr, tileId, tileSliceI32`.
/// NOTE(review): the embedded line numbers skip (86 -> 88, 90 -> 92,
/// 106 -> 110), so part of the signature and the switch/else scaffolding are
/// not visible in this listing.
86 static Operation *createStoreTileSliceIntrinsic(
88 arm_sme::TileSliceLayout layout,
Value maskOp,
Value ptr,
89 IntegerAttr tileId,
Value tileSliceI32) {
90 if (layout == arm_sme::TileSliceLayout::Horizontal) {
// Horizontal layout: one `st1*_horiz` intrinsic per tile type.
92 case arm_sme::ArmSMETileType::ZAB:
93 return rewriter.
create<arm_sme::aarch64_sme_st1b_horiz>(
94 loc, maskOp, ptr, tileId, tileSliceI32);
95 case arm_sme::ArmSMETileType::ZAH:
96 return rewriter.
create<arm_sme::aarch64_sme_st1h_horiz>(
97 loc, maskOp, ptr, tileId, tileSliceI32);
98 case arm_sme::ArmSMETileType::ZAS:
99 return rewriter.
create<arm_sme::aarch64_sme_st1w_horiz>(
100 loc, maskOp, ptr, tileId, tileSliceI32);
101 case arm_sme::ArmSMETileType::ZAD:
102 return rewriter.
create<arm_sme::aarch64_sme_st1d_horiz>(
103 loc, maskOp, ptr, tileId, tileSliceI32);
104 case arm_sme::ArmSMETileType::ZAQ:
105 return rewriter.
create<arm_sme::aarch64_sme_st1q_horiz>(
106 loc, maskOp, ptr, tileId, tileSliceI32);
// `st1*_vert` variants (presumably the non-horizontal branch; the `else` is
// not visible in this listing).
110 case arm_sme::ArmSMETileType::ZAB:
111 return rewriter.
create<arm_sme::aarch64_sme_st1b_vert>(
112 loc, maskOp, ptr, tileId, tileSliceI32);
113 case arm_sme::ArmSMETileType::ZAH:
114 return rewriter.
create<arm_sme::aarch64_sme_st1h_vert>(
115 loc, maskOp, ptr, tileId, tileSliceI32);
116 case arm_sme::ArmSMETileType::ZAS:
117 return rewriter.
create<arm_sme::aarch64_sme_st1w_vert>(
118 loc, maskOp, ptr, tileId, tileSliceI32);
119 case arm_sme::ArmSMETileType::ZAD:
120 return rewriter.
create<arm_sme::aarch64_sme_st1d_vert>(
121 loc, maskOp, ptr, tileId, tileSliceI32);
122 case arm_sme::ArmSMETileType::ZAQ:
123 return rewriter.
create<arm_sme::aarch64_sme_st1q_vert>(
124 loc, maskOp, ptr, tileId, tileSliceI32);
/// Reads the tile ID attribute of `op` through `getTileId()`.
/// NOTE(review): the check and the return statement are not visible in this
/// listing; the diagnostic text below suggests an error is reported when no
/// tile ID has been allocated yet — confirm against the full source.
129 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
130 auto tileId = op.getTileId();
133 "expected tile ID to be allocated before conversion to LLVM");
/// Builds a `memref.alloca` with two dynamic dimensions of the tile's element
/// type; both dynamic sizes are `vscale * minElements`.
/// NOTE(review): the function name line is not visible here; the call
/// `createAllocaForTile(rewriter, loc, func, tileOp)` further down matches
/// these parameters, so this is presumably that helper.
139 static memref::AllocaOp
141 FunctionOpInterface func,
142 arm_sme::ArmSMETileOpInterface tileOp) {
// The guard restores the rewriter's insertion point when this scope exits.
143 RewriterBase::InsertionGuard g(rewriter);
147 auto vscale = rewriter.
create<vector::VectorScaleOp>(loc);
148 auto tileElementType = tileOp.getTileType().getElementType();
// Memref type with both dimensions dynamic.
150 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
153 rewriter.
create<arith::ConstantIndexOp>(loc, minElements);
// Runtime extent of one dimension: vscale * minElements.
154 auto vectorLen = rewriter.
create<arith::MulIOp>(loc, vscale, minElementsOp);
// The same extent is used for both dynamic dimensions.
155 auto alloca = rewriter.
create<memref::AllocaOp>(
156 loc, memrefType,
ValueRange{vectorLen, vectorLen});
/// Looks through the ops of the first block of `func` for a `memref.alloca`
/// whose `kInMemoryTileIdAttr` integer attribute equals `tileId`; when the
/// search falls through, a new alloca is created with `createAllocaForTile`
/// and tagged with `kInMemoryTileIdAttr`.
/// NOTE(review): the early-return and `continue` lines are not visible in
/// this listing.
161 static memref::AllocaOp getOrCreateAllocaForTile(
163 arm_sme::ArmSMETileOpInterface tileOp,
unsigned tileId) {
// Scan the function's first block for an alloca already tagged for this tile.
166 for (
auto &op : func.getBlocks().front()) {
167 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
170 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
171 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
174 if (inMemoryTileId.getInt() == tileId)
// No match found by the scan above: create a new alloca and tag it.
178 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
179 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
/// Members of the `ConvertArmSMESpillsAndFillsToLLVM` pattern.
/// NOTE(review): the struct header and several member signatures are not
/// visible in this listing (the embedded line numbers skip), so the comments
/// below describe only the statements that are shown.
// Constructor: forwards `typeConverter` and `benefit` to the base initializer.
236 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
240 typeConverter, benefit) {}
// Rewrite body fragment: handles ops whose tile interface reports an
// in-memory tile.
245 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
247 if (!tileOp.isInMemoryTile())
251 "failed to allocate SME virtual tile to operation, tile value will go "
252 "through memory, expect degraded performance");
256 auto loc = tileOp.getLoc();
257 auto func = tileOp->getParentOfType<FunctionOpInterface>();
// Find or create the alloca associated with the op's current tile ID.
258 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
259 tileOp.getTileId().getInt());
// Re-assign the op to `zeroTileId` in place.
264 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
266 VectorType tileVectorType = tileOp.getTileType();
// The swap helper is invoked twice below (presumably once before and once
// after the op — the insertion-point changes are not visible; confirm).
268 auto swapInMemoryTileWithSMETileZero = [&] {
269 emitFullTileSwap(rewriter, loc, tileAlloca,
280 swapInMemoryTileWithSMETileZero();
283 swapInMemoryTileWithSMETileZero();
// Fragment of the helper called as `getInMemoryTileSlicePtr` below: converts
// the memref type, casts `tileMemory` with an unrealized conversion cast, and
// returns a strided element pointer at indices {sliceIndexI64, zero}.
292 auto llvmType = getTypeConverter()->convertType(tileMemory.
getType());
294 rewriter.
create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
295 auto zero = rewriter.
create<arith::ConstantIntOp>(loc, 0, 64);
296 auto sliceIndexI64 = rewriter.
create<arith::IndexCastOp>(
298 return getStridedElementPtr(
299 loc, llvm::cast<MemRefType>(tileMemory.
getType()),
300 descriptor.getResult(0), {sliceIndexI64, zero},
// Fragment of the helper called as `emitSliceSwap` below: exchanges one
// horizontal tile slice with the corresponding slice in `tileAlloca`.
307 arm_sme::ArmSMETileType tileType, VectorType sliceType,
308 IntegerAttr tileId,
Value sliceIndex)
const {
310 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
314 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
// Undef vector passed as the pad operand of the tile-slice read.
317 auto padVector = rewriter.
create<LLVM::UndefOp>(loc, sliceType);
320 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the current slice out of the tile first ...
322 auto currentTileSlice = rewriter.
create<arm_sme::aarch64_sme_read_horiz>(
323 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
// ... then load the in-memory slice into the tile ...
325 createLoadTileSliceIntrinsic(
326 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
327 allTruePredicate, slicePtr, tileId, sliceIndexI32);
// ... and finally store the previously read slice into the alloca.
329 auto zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
330 rewriter.
create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
// Fragment of the helper called as `emitFullTileSwap` above: emits an
// `scf.for` from 0 to `minNumElts * vscale` with step 1 and calls
// `emitSliceSwap` with the loop induction variable.
337 arm_sme::ArmSMETileType tileType, VectorType sliceType,
338 IntegerAttr tileId)
const {
339 RewriterBase::InsertionGuard guard(rewriter);
342 rewriter.
create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
343 auto lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
344 auto upperBound = rewriter.
create<arith::MulIOp>(
345 loc, minNumElts, rewriter.
create<vector::VectorScaleOp>(loc));
346 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
347 auto forOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
350 auto sliceIndex = forOp.getInductionVar();
351 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
/// Two-state flag used as a non-type template parameter in this file (default
/// `Yes`) and tested by `requiresSpillsAndFillsConversion()`.
356 enum class RequiresSpillsAndFills { Yes, No };
/// Pattern template parameterized by the source op and by whether spill/fill
/// conversion is required (default: `RequiresSpillsAndFills::Yes`).
/// NOTE(review): the struct header is not visible in this listing; the uses
/// `ConvertArmSMEOpToLLVMPattern<...>` elsewhere in this file match these
/// template parameters, so this is presumably that template.
362 template <
typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
363 RequiresSpillsAndFills::Yes>
// Alias exposing the source op type to generic code.
365 using ArmSMEOp = SourceOp;
// Compile-time query: true exactly when the flag is `Yes`.
368 static constexpr
bool requiresSpillsAndFillsConversion() {
369 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
/// Registers one conversion pattern. When the pattern requests spill/fill
/// conversion and its op type derives from the `ArmSMETileOpInterface` trait,
/// a `ConvertArmSMESpillsAndFillsToLLVM` pattern keyed on the op's name is
/// added as well.
/// NOTE(review): the function name line and the registration of `Pattern`
/// itself are not visible in this listing.
373 template <
typename Pattern>
// Compile-time check: spill/fill conversion requested AND op implements the
// tile op interface.
379 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
380 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
381 typename Pattern::ArmSMEOp>,
382 typename Pattern::ArmSMEOp>) {
385 patterns.
add<ConvertArmSMESpillsAndFillsToLLVM>(
386 Pattern::ArmSMEOp::getOperationName(), typeConverter,
/// Variadic wrapper: a fold expression over the comma operator calls
/// `addArmSMEConversionPattern` once for each type in `Patterns`.
393 template <
typename... Patterns>
397 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
/// Conversion pattern for `arm_sme::ZeroOp`: derives a zero mask from the tile
/// type and tile ID and creates an `aarch64_sme_zero` intrinsic.
/// NOTE(review): the per-case mask values and the final replacement are not
/// visible in this listing.
415 struct ZeroOpConversion :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
416 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
419 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
421 auto loc = zero.getLoc();
423 auto tileId = getTileIdOrError(zero);
431 arm_sme::ArmSMETileType tileType =
// Base mask chosen per tile type; any type other than the four listed cases
// reaches the unreachable below.
433 auto baseMaskForSize = [&] {
435 case arm_sme::ArmSMETileType::ZAB:
439 case arm_sme::ArmSMETileType::ZAH:
444 case arm_sme::ArmSMETileType::ZAS:
449 case arm_sme::ArmSMETileType::ZAD:
454 llvm_unreachable(
"bad element size");
// Shift the base mask left by the tile ID to form the final mask.
479 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
480 rewriter.
create<arm_sme::aarch64_sme_zero>(
/// Conversion pattern for `arm_sme::LoadTileSliceOp`: computes the element
/// pointer, casts the slice index, emits the matching load intrinsic through
/// `createLoadTileSliceIntrinsic`, and replaces the op with its own tile
/// operand.
491 struct LoadTileSliceConversion
492 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
493 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
496 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
497 arm_sme::LoadTileSliceOp::Adaptor adaptor,
499 auto loc = loadTileSliceOp.getLoc();
500 auto tileId = getTileIdOrError(loadTileSliceOp);
// Pointer to the addressed memref element.
504 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
506 adaptor.getIndices(), rewriter);
508 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Slice index converted with an unsigned index cast (target type not visible
// here; the variable name suggests i32).
511 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
515 auto maskOp = loadTileSliceOp.getMask();
517 auto tileVectorType = loadTileSliceOp.getVectorType();
519 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
522 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
523 tileId, tileSliceI32);
// The op's result is replaced by the tile value it received as operand.
527 rewriter.
replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
/// Conversion pattern for `arm_sme::StoreTileSliceOp`: computes the element
/// pointer, casts the slice index, and emits the matching store intrinsic
/// through `createStoreTileSliceIntrinsic`.
/// NOTE(review): the call that consumes the created intrinsic (note the extra
/// closing parenthesis on the last line) is not visible in this listing.
534 struct StoreTileSliceConversion
535 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
536 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
539 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
540 arm_sme::StoreTileSliceOp::Adaptor adaptor,
542 auto loc = storeTileSliceOp.getLoc();
543 auto tileVectorType = storeTileSliceOp.getVectorType();
545 auto tileId = getTileIdOrError(storeTileSliceOp);
// Pointer to the addressed memref element.
550 Value ptr = this->getStridedElementPtr(
551 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
552 adaptor.getIndices(), rewriter);
554 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Slice index converted with an unsigned index cast (target type not visible
// here; the variable name suggests i32).
557 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
560 auto maskOp = storeTileSliceOp.getMask();
562 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
566 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
568 tileId, tileSliceI32));
/// Conversion pattern for `arm_sme::MoveVectorToTileSliceOp`: writes the op's
/// vector into a tile slice with an all-active mask, using the `write_horiz`
/// or `write_vert` intrinsic according to the op's layout, then replaces the
/// op with its own tile operand.
575 struct MoveVectorToTileSliceConversion
576 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
577 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
580 matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
581 arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
583 auto loc = moveVectorToTileSliceOp.getLoc();
584 auto tileType = moveVectorToTileSliceOp.getTileType();
586 auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
590 auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
// Slice index converted with an unsigned index cast (target type not visible
// here; the variable name suggests i32).
593 auto tileSliceI32 = rewriter.
create<arith::IndexCastUIOp>(
// Constant `one` is splatted into a predicate vector: the all-active mask.
597 auto one = rewriter.
create<arith::ConstantOp>(
602 auto allActiveMask = rewriter.
create<vector::SplatOp>(loc, predTy, one);
// Pick the write intrinsic that matches the slice layout.
605 switch (moveVectorToTileSliceOp.getLayout()) {
606 case arm_sme::TileSliceLayout::Horizontal:
607 rewriter.
create<arm_sme::aarch64_sme_write_horiz>(
608 loc, tileId, tileSliceI32, allActiveMask,
609 moveVectorToTileSliceOp.getVector());
611 case arm_sme::TileSliceLayout::Vertical:
612 rewriter.
create<arm_sme::aarch64_sme_write_vert>(
613 loc, tileId, tileSliceI32, allActiveMask,
614 moveVectorToTileSliceOp.getVector());
// The op's result is replaced by the tile value it received as operand.
620 rewriter.
replaceOp(moveVectorToTileSliceOp,
621 moveVectorToTileSliceOp.getTile());
/// Conversion pattern for `arm_sme::MoveTileSliceToVectorOp`: prepares an
/// all-true predicate, a `zeroVector` constant and a cast slice index, then
/// dispatches on the op's layout (Horizontal / Vertical).
/// NOTE(review): the intrinsic created in each case is not visible in this
/// listing; only its argument list is shown.
628 struct MoveTileSliceToVectorConversion
629 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
630 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
633 matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
636 auto loc = moveTileSliceToVector.getLoc();
637 auto sliceType = moveTileSliceToVector.getSliceType();
638 auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
640 auto tileId = getTileIdOrError(moveTileSliceToVector);
// Predicate type: same shape as the slice, with i1 elements.
645 auto predicateType = sliceType.cloneWith({}, rewriter.
getI1Type());
646 auto allTruePredicate = rewriter.
create<arith::ConstantOp>(
650 auto zeroVector = rewriter.
create<arith::ConstantOp>(
654 auto sliceIndexI32 = rewriter.
create<arith::IndexCastOp>(
// Both layouts pass the same arguments; only the (unseen) callee differs.
658 switch (moveTileSliceToVector.getLayout()) {
659 case arm_sme::TileSliceLayout::Horizontal:
661 moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
662 tileId, sliceIndexI32);
664 case arm_sme::TileSliceLayout::Vertical:
666 moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
667 tileId, sliceIndexI32);
/// Conversion pattern for `arm_sme::OuterProductOp`: rejects combining kinds
/// other than `Add` ("unsupported kind") and unsupported result types
/// ("unsupported type"), substitutes an all-active mask when either mask is
/// absent, and emits an `aarch64_sme_mopa` intrinsic.
689 struct OuterProductOpConversion
690 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
691 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
694 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
695 arm_sme::OuterProductOp::Adaptor adaptor,
697 auto tileId = getTileIdOrError(outerProductOp);
// Type check: rank-2, all dims scalable, element type f16/bf16/f32/f64, and a
// shape comparison that involves the element bit width (right-hand side not
// visible in this listing).
701 auto isSupportedType = [](VectorType vectorType) {
715 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
718 auto elementType = vectorType.getElementType();
720 if (!elementType.isF16() && !elementType.isBF16() &&
721 !elementType.isF32() && !elementType.isF64())
725 vectorType.getElementTypeBitWidth();
726 return vectorType.getShape() ==
731 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
732 return outerProductOp.emitError(
"unsupported kind");
734 auto resultVectorType = outerProductOp.getResultType();
735 if (!isSupportedType(resultVectorType))
736 return outerProductOp.emitError(
"unsupported type");
738 auto loc = outerProductOp.getLoc();
740 Value acc = outerProductOp.getAcc();
// A ZeroOp bound to the same tile ID is created here (presumably only when
// no accumulator is given — the guarding condition is not visible).
743 auto zero = rewriter.
create<arm_sme::ZeroOp>(loc, resultVectorType);
744 zero.setTileId(tileId);
748 Value lhsMask = outerProductOp.getLhsMask();
749 Value rhsMask = outerProductOp.getRhsMask();
// If either mask is missing, both are replaced by an all-active constant.
751 if (!lhsMask || !rhsMask) {
753 outerProductOp.getLhsType().cloneWith({}, rewriter.
getI1Type());
754 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
756 lhsMask = allActiveMask;
757 rhsMask = allActiveMask;
761 rewriter.
create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
762 outerProductOp.getLhs(),
763 outerProductOp.getRhs());
/// Generic conversion pattern for widening outer-product ops: substitutes an
/// all-active mask when either mask is absent and emits the intrinsic given
/// by the second template parameter with
/// `(loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs())`.
/// NOTE(review): in this listing the two template parameter names are broken
/// across lines (`OuterProductW` / `ideningOp`); the uses below spell them
/// `OuterProductWideningOp` and `OuterProductWideningIntrOp`.
774 template <
class OuterProductW
ideningOp,
class OuterProductW
ideningIntrOp>
775 struct OuterProductWideningOpConversion
776 :
public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
777 using ConvertArmSMEOpToLLVMPattern<
778 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
781 matchAndRewrite(OuterProductWideningOp op,
782 typename OuterProductWideningOp::Adaptor adaptor,
784 auto tileId = getTileIdOrError(op);
789 Value acc = op.getAcc();
// A ZeroOp bound to the same tile ID is created here (presumably only when
// no accumulator is given — the guarding condition is not visible).
792 auto zero = rewriter.
create<arm_sme::ZeroOp>(loc, op.getResultType());
793 zero.setTileId(tileId);
797 Value lhsMask = op.getLhsMask();
798 Value rhsMask = op.getRhsMask();
// If either mask is missing, both are replaced by an all-active constant.
799 if (!lhsMask || !rhsMask) {
800 auto predTy = op.getLhsType().cloneWith({}, rewriter.
getI1Type());
801 Value allActiveMask = rewriter.
create<arith::ConstantOp>(
803 lhsMask = allActiveMask;
804 rhsMask = allActiveMask;
807 rewriter.
create<OuterProductWideningIntrOp>(
808 loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
/// Conversion pattern for `arm_sme::StreamingVLOp`, instantiated with
/// `RequiresSpillsAndFills::No`. Selects the `cntsb`/`cntsh`/`cntsw`/`cntsd`
/// intrinsic by the op's type size (Byte/Half/Word/Double), each created with
/// `i64Type` as result type.
/// NOTE(review): the enclosing lambda and the replacement call are only
/// partly visible; the last line shows the index type and the intrinsic's
/// first result being passed to it.
829 struct StreamingVLOpConversion
830 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
831 RequiresSpillsAndFills::No> {
832 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
835 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
836 arm_sme::StreamingVLOp::Adaptor adaptor,
838 auto loc = streamingVlOp.getLoc();
// One counting intrinsic per element size.
841 switch (streamingVlOp.getTypeSize()) {
842 case arm_sme::TypeSize::Byte:
843 return rewriter.
create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
844 case arm_sme::TypeSize::Half:
845 return rewriter.
create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
846 case arm_sme::TypeSize::Word:
847 return rewriter.
create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
848 case arm_sme::TypeSize::Double:
849 return rewriter.
create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
853 streamingVlOp, rewriter.
getIndexType(), intrOp->getResult(0));
/// Pass that converts ArmSME ops to LLVM intrinsics. The constructor stores
/// the `dumpTileLiveRanges` flag in the pass option of the same name.
/// NOTE(review): most of `runOnOperation` is not visible in this listing;
/// only the statements shown below are described.
862 struct ConvertArmSMEToLLVMPass
863 :
public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
864 ConvertArmSMEToLLVMPass(
bool dumpTileLiveRanges) {
865 this->dumpTileLiveRanges = dumpTileLiveRanges;
867 void runOnOperation()
override {
868 auto function = getOperation();
// Failure path: the pass is marked as failed (condition not visible here).
871 return signalPassFailure();
// CopyTileOp, GetTileOp and cf::BranchOp are singled out by this check.
886 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
// Predicate: is `type` a valid SME tile vector type?
889 auto isSMETileType = [](
Type type) {
890 return arm_sme::isValidSMETileVectorType(type);
// Diagnostic emitted on the offending op.
894 op->emitOpError(
"unexpected operation with SME tile type after "
895 "conversion to LLVM");
// Conversion-target setup fragment.
// NOTE(review): the function header and the opening of the first two calls
// are not visible in this listing. The list below is the template-argument
// list of a call on the ArmSME intrinsic ops (presumably
// `target.addLegalOp<...>()` — confirm against the full source).
907 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
908 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
909 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
910 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
911 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
912 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
913 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
914 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
915 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
916 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
917 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
918 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
919 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
920 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
921 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
922 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
923 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
924 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
925 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
926 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
927 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
928 arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
929 arm_sme::aarch64_sme_cntsd>();
// Tail of a dialect list (the call name is not visible here).
932 vector::VectorDialect, scf::SCFDialect,
933 memref::MemRefDialect>();
// GetTileOp, CopyTileOp and unrealized conversion casts are marked legal.
937 target.
addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
938 UnrealizedConversionCastOp>();
// Pattern-population fragment (function header not visible in this listing).
// Registers a type-conversion callback for `VectorType`; its body is not
// visible here.
943 converter.
addConversion([&](VectorType type) -> std::optional<Type> {
// Registers every ArmSME conversion pattern of this file in one call. Each
// widening outer-product op is paired with the intrinsic it lowers to.
951 addArmSMEConversionPatterns<
952 LoadTileSliceConversion, MoveTileSliceToVectorConversion,
953 MoveVectorToTileSliceConversion, StoreTileSliceConversion,
954 StreamingVLOpConversion, OuterProductOpConversion,
955 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
956 arm_sme::aarch64_sme_mopa_wide>,
957 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
958 arm_sme::aarch64_sme_mops_wide>,
959 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
960 arm_sme::aarch64_sme_smopa_za32>,
961 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
962 arm_sme::aarch64_sme_smops_za32>,
963 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
964 arm_sme::aarch64_sme_umopa_za32>,
965 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
966 arm_sme::aarch64_sme_umops_za32>,
967 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
968 arm_sme::aarch64_sme_smopa_wide>,
969 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
970 arm_sme::aarch64_sme_smops_wide>,
971 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
972 arm_sme::aarch64_sme_umopa_wide>,
973 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
974 arm_sme::aarch64_sme_umops_wide>,
975 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
976 arm_sme::aarch64_sme_sumopa_wide>,
977 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
978 arm_sme::aarch64_sme_sumops_wide>,
979 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
980 arm_sme::aarch64_sme_usmopa_wide>,
981 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
982 arm_sme::aarch64_sme_usmops_wide>,
983 ZeroOpConversion>(patterns, converter);
/// Factory fragment: returns a newly allocated `ConvertArmSMEToLLVMPass`
/// constructed with `dumpTileLiveRanges`.
/// NOTE(review): the function name line is not visible in this listing.
986 std::unique_ptr<Pass>
988 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for operation conversions targeting the LLVM IR dialect.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit within an SME tile.
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
constexpr unsigned MinStreamingVectorLengthInBits
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.