27#include "llvm/ADT/ScopeExit.h"
30#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31#include "mlir/Conversion/Passes.h.inc"
/// Name of the discardable attribute used to tag a memref.alloca with the ID
/// of the in-memory tile it backs. getOrCreateAllocaForTile reads this
/// attribute to find an existing alloca and sets it on newly created ones.
38static constexpr StringLiteral kInMemoryTileIdAttr(
"arm_sme.in_memory_tile_id");
/// Creates the SME tile-slice load intrinsic op (aarch64_sme_ld1{b,h,w,d,q})
/// that matches `layout` (horizontal vs. non-horizontal) and the SME tile
/// type (ZAB/ZAH/ZAS/ZAD/ZAQ). The op is created with maskOp, ptr, tileId and
/// tileSliceI32 as its operands.
/// Returns the created op. A tile type not handled by the switches reaches
/// llvm_unreachable.
41static Operation *createLoadTileSliceIntrinsic(
44 IntegerAttr tileId,
Value tileSliceI32) {
// Horizontal layout: pick the *_horiz load intrinsic for the tile type.
45 if (layout == arm_sme::TileSliceLayout::Horizontal) {
47 case arm_sme::ArmSMETileType::ZAB:
48 return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp,
ptr,
49 tileId, tileSliceI32);
50 case arm_sme::ArmSMETileType::ZAH:
51 return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp,
ptr,
52 tileId, tileSliceI32);
53 case arm_sme::ArmSMETileType::ZAS:
54 return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp,
ptr,
55 tileId, tileSliceI32);
56 case arm_sme::ArmSMETileType::ZAD:
57 return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp,
ptr,
58 tileId, tileSliceI32);
59 case arm_sme::ArmSMETileType::ZAQ:
60 return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp,
ptr,
61 tileId, tileSliceI32);
// Otherwise (non-horizontal layout): pick the *_vert load intrinsic.
65 case arm_sme::ArmSMETileType::ZAB:
66 return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp,
ptr,
67 tileId, tileSliceI32);
68 case arm_sme::ArmSMETileType::ZAH:
69 return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp,
ptr,
70 tileId, tileSliceI32);
71 case arm_sme::ArmSMETileType::ZAS:
72 return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp,
ptr,
73 tileId, tileSliceI32);
74 case arm_sme::ArmSMETileType::ZAD:
75 return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp,
ptr,
76 tileId, tileSliceI32);
77 case arm_sme::ArmSMETileType::ZAQ:
78 return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp,
ptr,
79 tileId, tileSliceI32);
// Every case above returns; reaching here means an unhandled tile type.
83 llvm_unreachable(
"unknown type in createLoadTileSliceIntrinsic");
/// Creates the SME tile-slice store intrinsic op (aarch64_sme_st1{b,h,w,d,q})
/// that matches `layout` (horizontal vs. non-horizontal) and the SME tile
/// type (ZAB/ZAH/ZAS/ZAD/ZAQ). The op is created with maskOp, ptr, tileId and
/// tileSliceI32 as its operands.
/// Returns the created op. A tile type not handled by the switches reaches
/// llvm_unreachable.
87static Operation *createStoreTileSliceIntrinsic(
90 IntegerAttr tileId,
Value tileSliceI32) {
// Horizontal layout: pick the *_horiz store intrinsic for the tile type.
91 if (layout == arm_sme::TileSliceLayout::Horizontal) {
93 case arm_sme::ArmSMETileType::ZAB:
94 return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp,
ptr,
95 tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAH:
97 return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp,
ptr,
98 tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAS:
100 return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp,
ptr,
101 tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAD:
103 return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp,
ptr,
104 tileId, tileSliceI32);
105 case arm_sme::ArmSMETileType::ZAQ:
106 return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp,
ptr,
107 tileId, tileSliceI32);
// Otherwise (non-horizontal layout): pick the *_vert store intrinsic.
111 case arm_sme::ArmSMETileType::ZAB:
112 return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp,
ptr,
113 tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAH:
115 return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp,
ptr,
116 tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAS:
118 return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp,
ptr,
119 tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAD:
121 return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp,
ptr,
122 tileId, tileSliceI32);
123 case arm_sme::ArmSMETileType::ZAQ:
124 return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp,
ptr,
125 tileId, tileSliceI32);
// Every case above returns; reaching here means an unhandled tile type.
128 llvm_unreachable(
"unknown type in createStoreTileSliceIntrinsic");
// NOTE(review): isolated fragment. This is presumably the diagnostic emitted
// by getTileIdOrError when an op has no tile ID assigned; the surrounding
// code is not visible in this listing.
135 "expected tile ID to be allocated before conversion to LLVM");
/// Creates a memref.alloca to hold an in-memory copy of `tileOp`'s tile.
/// The memref has a dynamic 2-D shape and the tile's element type, and the
/// alloca is sized from vectorLen = vscale * minElementsOp.
/// NOTE(review): the rest of the AllocaOp operand list is not visible here;
/// presumably both dynamic dimensions use vectorLen. Confirm.
141static memref::AllocaOp
143 FunctionOpInterface
func,
149 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
150 auto tileElementType = tileOp.
getTileType().getElementType();
151 auto memrefType = MemRefType::get(
152 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
156 auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
157 auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
/// Returns the memref.alloca that backs the in-memory tile `tileId`.
/// It scans the ops of the first block of `func` for an alloca whose
/// kInMemoryTileIdAttr integer equals `tileId`. If none matches, it creates a
/// new alloca with createAllocaForTile and tags it with kInMemoryTileIdAttr.
163static memref::AllocaOp getOrCreateAllocaForTile(
168 for (
auto &op :
func.getBlocks().front()) {
169 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
172 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
173 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
176 if (inMemoryTileId.getInt() == tileId)
// No tagged alloca found for this tile ID: create one and tag it.
180 auto alloca = createAllocaForTile(rewriter, loc,
func, tileOp);
181 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
/// Constructs the spills-and-fills fallback pattern for operations named
/// `rootOpName`. The pattern uses the context of `typeConverter` and the
/// given pattern `benefit`.
238 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
239 const LLVMTypeConverter &typeConverter,
240 PatternBenefit benefit)
241 : ConvertToLLVMPattern(rootOpName, &typeConverter.
getContext(),
242 typeConverter, benefit) {}
/// Fallback rewrite for an ArmSME tile op that was not given a real SME tile.
/// The guard that selects such ops is on lines not visible here; the warning
/// text below describes the case. The rewrite:
///  1. emits a "degraded performance" warning,
///  2. gets or creates an alloca for the op's in-memory tile in the enclosing
///     function,
///  3. rewrites the op in place to use SME tile ID 0, and
///  4. emits a full swap between the in-memory tile and SME tile 0 both
///     immediately before and immediately after the op.
/// Presumably the swap pair lets the op work on the in-memory tile's data
/// while restoring tile 0's previous contents afterwards.
245 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
246 ConversionPatternRewriter &rewriter)
const override {
247 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
253 "failed to allocate SME virtual tile to operation, tile value will go "
254 "through memory, expect degraded performance");
258 auto loc = tileOp.
getLoc();
259 auto func = tileOp->getParentOfType<FunctionOpInterface>();
260 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
// Redirect the op to SME tile 0.
265 auto zeroTileId = rewriter.getI32IntegerAttr(0);
266 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.
setTileId(zeroTileId); });
// Slice type: the tile vector type with its leading dimension dropped.
269 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
270 auto swapInMemoryTileWithSMETileZero = [&] {
271 emitFullTileSwap(rewriter, loc, tileAlloca,
// Swap in before the op...
280 rewriter.setInsertionPoint(op);
282 swapInMemoryTileWithSMETileZero();
// ...and swap again after it.
283 rewriter.setInsertionPointAfter(op);
285 swapInMemoryTileWithSMETileZero();
/// Returns an LLVM pointer into row `sliceIndex` of the in-memory tile
/// `tileMemory`. It casts the memref to its converted LLVM type through an
/// UnrealizedConversionCastOp, casts `sliceIndex` to i64, and computes a
/// strided element pointer at indices {sliceIndexI64, zero}.
/// NOTE(review): `zero` is defined on a line not visible here; presumably it
/// is the constant 0, which makes this the first element of the row.
292 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
293 Value tileMemory, Value sliceIndex)
const {
294 auto llvmType = getTypeConverter()->convertType(tileMemory.
getType());
296 UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
298 auto sliceIndexI64 = arith::IndexCastOp::create(
299 rewriter, loc, rewriter.
getI64Type(), sliceIndex);
301 static_cast<ConversionPatternRewriter &
>(rewriter), loc,
302 llvm::cast<MemRefType>(tileMemory.
getType()), descriptor.getResult(0),
303 {sliceIndexI64, zero});
/// Swaps horizontal slice `sliceIndex` between SME tile `tileId` and the
/// in-memory tile `tileAlloca`, in three steps:
///  1. read the current tile slice with aarch64_sme_read_horiz, using a
///     poison pad vector and the predicate allTruePredicate;
///  2. load the in-memory slice (via getInMemoryTileSlicePtr) into the tile
///     with createLoadTileSliceIntrinsic, using horizontal layout;
///  3. store the slice read in step 1 into `tileAlloca` with vector.store.
/// NOTE(review): the predicate's constant value is built on lines not
/// visible here; it is assumed all-true per its name.
/// NOTE(review): this uses a signed arith::IndexCastOp for the i32 slice
/// index, whereas the load/store/insert lowerings use IndexCastUIOp.
308 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
309 arm_sme::ArmSMETileType tileType, VectorType sliceType,
310 IntegerAttr tileId, Value sliceIndex)
const {
312 auto sliceIndexI32 = arith::IndexCastOp::create(
313 rewriter, loc, rewriter.
getI32Type(), sliceIndex);
315 auto predicateType = sliceType.clone(rewriter.
getI1Type());
316 auto allTruePredicate = arith::ConstantOp::create(
319 auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
322 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Step 1: read the current slice out of the SME tile.
324 auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
325 rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
// Step 3: write the previously read tile slice back to memory.
333 vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
/// Swaps the whole SME tile `tileId` with the in-memory tile `tileAlloca`.
/// It emits an scf.for from lowerBound to upperBound by `step`, and the loop
/// body calls emitSliceSwap for each induction value. The product
/// minNumElts * vscale is computed for the bounds; it is presumably the upper
/// bound (number of slices), but the assignment is not visible. Confirm.
/// The InsertionGuard restores the rewriter's insertion point on return.
339 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId)
const {
342 RewriterBase::InsertionGuard guard(rewriter);
348 arith::MulIOp::create(rewriter, loc, minNumElts,
349 vector::VectorScaleOp::create(rewriter, loc));
352 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
355 auto sliceIndex = forOp.getInductionVar();
356 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
/// Selects whether an ArmSME op pattern also gets the
/// ConvertArmSMESpillsAndFillsToLLVM fallback registered for its op.
/// The value is queried through requiresSpillsAndFillsConversion().
361enum class RequiresSpillsAndFills { Yes, No };
/// Base class for ArmSME op -> LLVM conversion patterns.
/// It exposes the source op type as `ArmSMEOp` and reports, through
/// requiresSpillsAndFillsConversion(), whether the spills-and-fills fallback
/// should be registered. The default is RequiresSpillsAndFills::Yes.
367template <
typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
368 RequiresSpillsAndFills::Yes>
370 using ArmSMEOp = SourceOp;
371 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
373 static constexpr bool requiresSpillsAndFillsConversion() {
374 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
/// Registers the conversion for `Pattern`. When
/// Pattern::requiresSpillsAndFillsConversion() holds, together with a further
/// compile-time check on Pattern::ArmSMEOp that is only partially visible
/// here, it also registers ConvertArmSMESpillsAndFillsToLLVM for the same
/// operation name.
/// NOTE(review): the line that adds `Pattern` itself is not visible in this
/// listing.
378template <
typename Pattern>
384 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
386 typename Pattern::ArmSMEOp>,
387 typename Pattern::ArmSMEOp>) {
390 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
391 Pattern::ArmSMEOp::getOperationName(), typeConverter,
/// Calls addArmSMEConversionPattern for every type in `Patterns...`, using a
/// C++17 fold over the comma operator.
398template <
typename... Patterns>
402 (addArmSMEConversionPattern<Patterns>(
patterns, typeConverter), ...);
/// Lowers arm_sme.zero to the aarch64_sme_zero intrinsic.
/// It picks a base mask for the tile type (ZAB/ZAH/ZAS/ZAD) and shifts it
/// left by the op's tile ID to form the zero mask. It then emits
/// aarch64_sme_zero with that mask and replaces the op with an
/// arm_sme.get_tile of the same vector type, created at the start of the op's
/// block.
/// NOTE(review): baseMaskForSize is written as a lambda; presumably it is
/// invoked immediately (`}()`) on a line not visible here. Confirm.
420struct ZeroOpConversion :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
421 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
424 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter)
const override {
426 auto loc = zero.getLoc();
428 auto tileId = getTileIdOrError(zero);
436 arm_sme::ArmSMETileType tileType =
438 auto baseMaskForSize = [&] {
440 case arm_sme::ArmSMETileType::ZAB:
444 case arm_sme::ArmSMETileType::ZAH:
449 case arm_sme::ArmSMETileType::ZAS:
454 case arm_sme::ArmSMETileType::ZAD:
459 llvm_unreachable(
"bad element size");
// The tile ID selects which copy of the base mask pattern is set.
484 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
485 arm_sme::aarch64_sme_zero::create(rewriter, loc,
486 rewriter.getI32IntegerAttr(zeroMask));
// Replace the zero op with a get_tile placed at the start of its block.
491 rewriter.setInsertionPointToStart(zero->getBlock());
492 rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
/// Lowers arm_sme.load_tile_slice to an SME ld1 intrinsic.
/// It computes an element pointer from the adaptor's base memref and
/// indices, zero-extends the slice index to i32 with IndexCastUIOp, and
/// emits the layout- and tile-type-specific load via
/// createLoadTileSliceIntrinsic with the op's mask. The op is then replaced
/// with its input tile value.
/// NOTE(review): `tileType` and the head of the pointer computation are on
/// lines not visible here.
499struct LoadTileSliceConversion
500 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
504 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505 arm_sme::LoadTileSliceOp::Adaptor adaptor,
506 ConversionPatternRewriter &rewriter)
const override {
507 auto loc = loadTileSliceOp.getLoc();
508 auto tileId = getTileIdOrError(loadTileSliceOp);
513 rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
514 adaptor.getIndices());
516 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
519 auto tileSliceI32 = arith::IndexCastUIOp::create(
520 rewriter, loc, rewriter.getI32Type(), tileSlice);
523 auto maskOp = loadTileSliceOp.getMask();
525 auto tileVectorType = loadTileSliceOp.getVectorType();
527 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
530 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531 tileId, tileSliceI32);
// The intrinsic updates the tile in place, so the op's result is its input tile.
535 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
/// Lowers arm_sme.store_tile_slice to an SME st1 intrinsic.
/// It computes an element pointer from the adaptor's base memref and
/// indices, and zero-extends the slice index to i32 with IndexCastUIOp. It
/// then replaces the op with the layout- and tile-type-specific store
/// created by createStoreTileSliceIntrinsic, using the op's mask.
/// NOTE(review): `tileType` and some intrinsic arguments are on lines not
/// visible here.
542struct StoreTileSliceConversion
543 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
547 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548 arm_sme::StoreTileSliceOp::Adaptor adaptor,
549 ConversionPatternRewriter &rewriter)
const override {
550 auto loc = storeTileSliceOp.getLoc();
551 auto tileVectorType = storeTileSliceOp.getVectorType();
553 auto tileId = getTileIdOrError(storeTileSliceOp);
559 rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560 adaptor.getIndices());
562 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
565 auto tileSliceI32 = arith::IndexCastUIOp::create(
566 rewriter, loc, rewriter.getI32Type(), tileSlice);
568 auto maskOp = storeTileSliceOp.getMask();
570 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
573 rewriter.replaceOp(storeTileSliceOp,
574 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
576 tileId, tileSliceI32));
/// Lowers arm_sme.insert_tile_slice to aarch64_sme_write_horiz or
/// aarch64_sme_write_vert, chosen by the op's layout.
/// It zero-extends the slice index to i32 with IndexCastUIOp and broadcasts
/// an i1 `true` into a predicate vector whose length is the tile's first
/// dimension. The chosen write intrinsic stores the op's vector into slice
/// `tileSliceI32` of tile `tileId`, and the op is replaced with its input
/// tile value.
/// NOTE(review): the binding of the broadcast result to allActiveMask is on
/// a line not visible here.
583struct InsertTileSliceConversion
584 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
585 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
588 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
589 arm_sme::InsertTileSliceOp::Adaptor adaptor,
590 ConversionPatternRewriter &rewriter)
const override {
591 auto loc = insertTileSliceOp.getLoc();
592 auto tileType = insertTileSliceOp.getTileType();
594 auto tileId = getTileIdOrError(insertTileSliceOp);
598 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
601 auto tileSliceI32 = arith::IndexCastUIOp::create(
602 rewriter, loc, rewriter.getI32Type(), tileSlice);
// Build an all-active i1 predicate sized by the tile's first dimension.
605 auto one = arith::ConstantOp::create(
606 rewriter, loc, rewriter.getI1Type(),
607 rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
608 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
611 vector::BroadcastOp::create(rewriter, loc, predTy, one);
614 switch (insertTileSliceOp.getLayout()) {
615 case arm_sme::TileSliceLayout::Horizontal:
616 arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
617 tileSliceI32, allActiveMask,
618 insertTileSliceOp.getVector());
620 case arm_sme::TileSliceLayout::Vertical:
621 arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
622 tileSliceI32, allActiveMask,
623 insertTileSliceOp.getVector());
// The intrinsic updates the tile in place, so the op's result is its input tile.
629 rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
/// Lowers arm_sme.extract_tile_slice to aarch64_sme_read_horiz or
/// aarch64_sme_read_vert, chosen by the op's layout.
/// The read uses an i1 predicate of the slice shape (allTruePredicate) and a
/// zero vector of the slice type as the passthrough. The read intrinsic
/// replaces the original op.
/// NOTE(review): the slice index is cast to i32 with the signed
/// arith::IndexCastOp, whereas the load/store/insert lowerings use
/// IndexCastUIOp. Slice indices are presumably non-negative, so the results
/// should agree, but the inconsistency is worth confirming or unifying.
636struct ExtractTileSliceConversion
637 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
638 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
641 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
642 ConversionPatternRewriter &rewriter)
const override {
643 auto loc = extractTileSlice.getLoc();
644 auto sliceType = extractTileSlice.getSliceType();
645 auto sliceIndex = extractTileSlice.getTileSliceIndex();
647 auto tileId = getTileIdOrError(extractTileSlice);
652 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
653 auto allTruePredicate = arith::ConstantOp::create(
657 auto zeroVector = arith::ConstantOp::create(
658 rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
661 auto sliceIndexI32 = arith::IndexCastOp::create(
662 rewriter, loc, rewriter.getI32Type(), sliceIndex);
665 switch (extractTileSlice.getLayout()) {
666 case arm_sme::TileSliceLayout::Horizontal:
667 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
668 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
671 case arm_sme::TileSliceLayout::Vertical:
672 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
673 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
/// Lowers arm_sme.outerproduct to the aarch64_sme_mopa intrinsic.
/// Only CombiningKind::Add is accepted; any other kind emits
/// "unsupported kind". The result type must be rank-2 with all dims
/// scalable, have an f16/bf16/f32/f64 element type, and have shape
/// [minNumElts, minNumElts], where minNumElts is derived from the element bit
/// width. Otherwise "unsupported type" is emitted.
/// An arm_sme.zero on the same tile is created for the accumulator case; the
/// null-acc check and assignment are on lines not visible here. If either
/// mask is missing, a single i1 constant of the LHS shape (allActiveMask) is
/// used for both masks. After emitting mopa on `tileId`, the op is replaced
/// with `acc`.
/// NOTE(review): this passes outerProductOp.getLhs()/getRhs() (the original
/// operands), while OuterProductWideningOpConversion passes the adaptor's
/// operands. Confirm which is intended and make them consistent.
696struct OuterProductOpConversion
697 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
698 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
701 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
702 arm_sme::OuterProductOp::Adaptor adaptor,
703 ConversionPatternRewriter &rewriter)
const override {
704 auto tileId = getTileIdOrError(outerProductOp);
708 auto isSupportedType = [](VectorType vectorType) {
722 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
725 auto elementType = vectorType.getElementType();
727 if (!elementType.isF16() && !elementType.isBF16() &&
728 !elementType.isF32() && !elementType.isF64())
732 vectorType.getElementTypeBitWidth();
733 return vectorType.getShape() ==
734 ArrayRef<int64_t>({minNumElts, minNumElts});
738 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
739 return outerProductOp.emitError(
"unsupported kind");
741 auto resultVectorType = outerProductOp.getResultType();
742 if (!isSupportedType(resultVectorType))
743 return outerProductOp.emitError(
"unsupported type");
745 auto loc = outerProductOp.getLoc();
747 Value acc = outerProductOp.getAcc();
750 auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
751 zero.setTileId(tileId);
755 Value lhsMask = outerProductOp.getLhsMask();
756 Value rhsMask = outerProductOp.getRhsMask();
// If either mask is absent, use the same all-active mask for both.
758 if (!lhsMask || !rhsMask) {
760 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
761 Value allActiveMask = arith::ConstantOp::create(
763 lhsMask = allActiveMask;
764 rhsMask = allActiveMask;
768 arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
769 outerProductOp.getLhs(),
770 outerProductOp.getRhs());
774 rewriter.replaceOp(outerProductOp, acc);
/// Generic lowering of an ArmSME widening outer-product op to the intrinsic
/// op `OuterProductWideningIntrOp`, on the op's tile.
/// An arm_sme.zero on the same tile is created for the accumulator case; the
/// null-acc check is on lines not visible here. If either mask is missing, a
/// single i1 constant of the LHS shape (allActiveMask) is used for both
/// masks. The intrinsic receives the adaptor's LHS/RHS operands, and the op
/// is replaced with `acc`.
781template <
class OuterProductW
ideningOp,
class OuterProductW
ideningIntrOp>
782struct OuterProductWideningOpConversion
783 :
public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
784 using ConvertArmSMEOpToLLVMPattern<
785 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
788 matchAndRewrite(OuterProductWideningOp op,
789 typename OuterProductWideningOp::Adaptor adaptor,
790 ConversionPatternRewriter &rewriter)
const override {
791 auto tileId = getTileIdOrError(op);
795 auto loc = op.getLoc();
796 Value acc = op.getAcc();
799 auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
800 zero.setTileId(tileId);
804 Value lhsMask = op.getLhsMask();
805 Value rhsMask = op.getRhsMask();
// If either mask is absent, use the same all-active mask for both.
806 if (!lhsMask || !rhsMask) {
807 auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
808 Value allActiveMask = arith::ConstantOp::create(
810 lhsMask = allActiveMask;
811 rhsMask = allActiveMask;
814 OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
815 adaptor.getLhs(), adaptor.getRhs());
819 rewriter.replaceOp(op, acc);
/// Lowers arm_sme.streaming_vl.
/// It emits aarch64_sme_cntsd (i64), casts the result to index, and replaces
/// the op with cntsdIdx * scale. The definition of `scale` is on lines not
/// visible here; presumably it depends on the op's type size.
/// Registered with RequiresSpillsAndFills::No, so no spills-and-fills
/// fallback is added for this op.
838struct StreamingVLOpConversion
839 :
public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
840 RequiresSpillsAndFills::No> {
841 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
844 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
845 arm_sme::StreamingVLOp::Adaptor adaptor,
846 ConversionPatternRewriter &rewriter)
const override {
847 auto loc = streamingVlOp.getLoc();
848 auto i64Type = rewriter.getI64Type();
849 auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
850 auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
851 rewriter.getIndexType(), cntsd);
855 rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
/// Merges each run of consecutive aarch64_sme_zero ops in `block` into a
/// single aarch64_sme_zero whose tile mask is the bitwise OR of their masks,
/// then erases the originals.
/// A "run" of one op is left unchanged. A scope-exit cleanup clears
/// zeroOpsToMerge whenever replaceMergedZeroOps returns.
/// NOTE(review): the cleanup presumably also resets mergedZeroMask; that
/// line is not visible here.
862static void mergeConsecutiveTileZerosInBlock(
Block *block) {
863 uint32_t mergedZeroMask = 0;
865 auto replaceMergedZeroOps = [&] {
866 auto cleanup = llvm::make_scope_exit([&] {
868 zeroOpsToMerge.clear();
// Nothing to merge for an empty run or a single zero op.
870 if (zeroOpsToMerge.size() <= 1)
873 arm_sme::aarch64_sme_zero::create(
874 rewriter, zeroOpsToMerge.front().getLoc(),
875 rewriter.getI32IntegerAttr(mergedZeroMask));
876 for (
auto zeroOp : zeroOpsToMerge)
877 rewriter.eraseOp(zeroOp);
// Accumulate zero ops; any other op presumably ends the current run.
880 if (
auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
881 mergedZeroMask |= zeroOp.getTileMask();
882 zeroOpsToMerge.push_back(zeroOp);
884 replaceMergedZeroOps();
// Flush a run that reaches the end of the block.
887 replaceMergedZeroOps();
/// Pass converting ArmSME ops to LLVM intrinsic ops. From the visible
/// fragment, it merges consecutive aarch64_sme_zero ops in every block of the
/// function. It then reports "unexpected operation with SME tile type after
/// conversion to LLVM" for ops that still use an SME tile type.
/// NOTE(review): arm_sme.copy_tile, arm_sme.get_tile and cf.br appear to be
/// exempt from that check, possibly along with others; the full condition
/// is not visible here.
894struct ConvertArmSMEToLLVMPass
914 function->walk(mergeConsecutiveTileZerosInBlock);
920 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
923 auto isSMETileType = [](
Type type) {
928 op->
emitOpError(
"unexpected operation with SME tile type after "
929 "conversion to LLVM");
// Legality for the ArmSME -> LLVM conversion. The ArmSME dialect is illegal
// as a whole, and the aarch64_sme_* intrinsic ops listed below are
// re-admitted; the call head is not visible, presumably addLegalOp.
// The arith, vector, scf and memref dialects are legal (other dialects on
// the missing line are not visible). arm_sme.get_tile, arm_sme.copy_tile and
// unrealized conversion casts also stay legal.
939 target.addIllegalDialect<arm_sme::ArmSMEDialect>();
941 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
942 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
943 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
944 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
945 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
946 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
947 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
948 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
949 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
950 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
951 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
952 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
953 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
954 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
955 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
956 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
957 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
958 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
959 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
960 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
961 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
962 target.addLegalDialect<arith::ArithDialect,
964 vector::VectorDialect, scf::SCFDialect,
965 memref::MemRefDialect>();
// get_tile/copy_tile are produced or kept by the lowerings (e.g. ZeroOpConversion).
969 target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
970 UnrealizedConversionCastOp>();
// Registers a VectorType conversion (its body is not visible here) and every
// ArmSME conversion pattern through addArmSMEConversionPatterns. This covers
// tile-slice load/extract/insert/store, streaming_vl, outer product, and
// zero. Each widening outer-product op is paired with the intrinsic it lowers
// to; for example, FMopa2Way goes to mopa_wide, SMopa2Way to smopa_za32, and
// SMopa4Way to smopa_wide.
975 converter.addConversion([&](VectorType type) -> std::optional<Type> {
983 addArmSMEConversionPatterns<
984 LoadTileSliceConversion, ExtractTileSliceConversion,
985 InsertTileSliceConversion, StoreTileSliceConversion,
986 StreamingVLOpConversion, OuterProductOpConversion,
987 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
988 arm_sme::aarch64_sme_mopa_wide>,
989 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
990 arm_sme::aarch64_sme_mops_wide>,
991 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
992 arm_sme::aarch64_sme_smopa_za32>,
993 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
994 arm_sme::aarch64_sme_smops_za32>,
995 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
996 arm_sme::aarch64_sme_umopa_za32>,
997 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
998 arm_sme::aarch64_sme_umops_za32>,
999 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1000 arm_sme::aarch64_sme_smopa_wide>,
1001 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1002 arm_sme::aarch64_sme_smops_wide>,
1003 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1004 arm_sme::aarch64_sme_umopa_wide>,
1005 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1006 arm_sme::aarch64_sme_umops_wide>,
1007 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1008 arm_sme::aarch64_sme_sumopa_wide>,
1009 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1010 arm_sme::aarch64_sme_sumops_wide>,
1011 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1012 arm_sme::aarch64_sme_usmopa_wide>,
1013 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1014 arm_sme::aarch64_sme_usmops_wide>,
1015 ZeroOpConversion>(
patterns, converter);
/// Creates the ArmSME-to-LLVM conversion pass and forwards
/// `dumpTileLiveRanges` to its constructor.
1018std::unique_ptr<Pass>
1020 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
Block represents an ordered list of Operations.
IntegerAttr getI32IntegerAttr(int32_t value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for operation conversions targeting the LLVM IR dialect.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
FunctionOpInterface getOperation()
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Operation is the basic unit of execution within MLIR.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
VectorType getTileType()
Returns the VectorType of the tile used by this operation.
mlir::IntegerAttr getTileId()
Returns the tile ID assigned to this operation.
void setTileId(mlir::IntegerAttr tileId)
Sets the tile ID for this operation.
::mlir::Pass::Option< bool > dumpTileLiveRanges
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
unsigned getSizeInBytes(TypeSize type)
Return the size represented by arm_sme::TypeSize in bytes.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
constexpr unsigned MinStreamingVectorLengthInBits
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.