MLIR 22.0.0git
EnableArmStreaming.cpp
Go to the documentation of this file.
1//===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
10// (SSVE) mode [1][2] by adding either of the following attributes to
11// 'func.func' ops:
12//
13// * 'arm_streaming' (default)
14// * 'arm_locally_streaming'
15//
16// It can also optionally enable the ZA storage array.
17//
18// Streaming-mode is part of the interface (ABI) for functions with the
19// first attribute and it's the responsibility of the caller to manage
20// PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
21// backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
22// streaming functions.
23//
24// In locally streaming functions PSTATE.SM is kept internal and managed by
25// the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
26// 'smstop sm' in the prologue / epilogue for functions with this
27// attribute.
28//
29// [1] https://developer.arm.com/documentation/ddi0616/aa
30// [2] https://llvm.org/docs/AArch64SME.html
31// [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
32// [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
33//
34//===----------------------------------------------------------------------===//
35
38#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
39
41
42#define DEBUG_TYPE "enable-arm-streaming"
43
44namespace mlir {
45namespace arm_sme {
46#define GEN_PASS_DEF_ENABLEARMSTREAMING
47#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
48} // namespace arm_sme
49} // namespace mlir
50
51using namespace mlir;
52using namespace mlir::arm_sme;
53namespace {
54
55constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
57
58template <typename... Ops>
59constexpr auto opList() {
60 return std::array{TypeID::get<Ops>()...};
61}
62
63bool isScalableVector(Type type) {
64 if (auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
66 return false;
67}
68
69struct EnableArmStreamingPass
70 : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps, bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
74 this->zaMode = zaMode;
75 this->ifRequiredByOps = ifRequiredByOps;
76 this->ifScalableAndSupported = ifScalableAndSupported;
77 }
78 void runOnOperation() override {
79 auto function = getOperation();
80
81 if (ifRequiredByOps && ifScalableAndSupported) {
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
85 return signalPassFailure();
86 }
87
88 if (ifRequiredByOps) {
89 bool foundTileOp = false;
90 function.walk([&](Operation *op) {
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
92 foundTileOp = true;
93 return WalkResult::interrupt();
94 }
95 return WalkResult::advance();
96 });
97 if (!foundTileOp)
98 return;
99 }
100
101 if (ifScalableAndSupported) {
102 // FIXME: This should be based on target information (i.e., the presence
103 // of FEAT_SME_FA64). This currently errs on the side of caution. If
104 // possible gathers/scatters should be lowered regular vector loads/stores
105 // before invoking this pass.
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction = false;
108 function.walk([&](Operation *op) {
109 if (llvm::is_contained(disallowedOperations,
110 op->getName().getTypeID())) {
111 isCompatibleScalableFunction = false;
112 return WalkResult::interrupt();
113 }
114 if (!isCompatibleScalableFunction &&
115 (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
116 llvm::any_of(op->getResultTypes(), isScalableVector))) {
117 isCompatibleScalableFunction = true;
118 }
119 return WalkResult::advance();
120 });
121 if (!isCompatibleScalableFunction)
122 return;
123 }
124
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
127 return;
128
129 auto unitAttr = UnitAttr::get(&getContext());
130
131 function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
132
133 // The pass currently only supports enabling ZA when in streaming-mode, but
134 // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
135 // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
136 // supporting this later.
137 if (zaMode != ArmZaMode::Disabled)
138 function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
139 }
140};
141} // namespace
142
144 const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
145 bool ifRequiredByOps, bool ifScalableAndSupported) {
146 return std::make_unique<EnableArmStreamingPass>(
147 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
148}
b getContext())
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.