MLIR  20.0.0git
EnableArmStreaming.cpp
Go to the documentation of this file.
1 //===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
10 // (SSVE) mode [1][2] by adding either of the following attributes to
11 // 'func.func' ops:
12 //
13 // * 'arm_streaming' (default)
14 // * 'arm_locally_streaming'
15 //
16 // It can also optionally enable the ZA storage array.
17 //
18 // Streaming-mode is part of the interface (ABI) for functions with the
19 // first attribute and it's the responsibility of the caller to manage
20 // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
21 // backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
22 // streaming functions.
23 //
24 // In locally streaming functions PSTATE.SM is kept internal and managed by
25 // the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
26 // 'smstop sm' in the prologue / epilogue for functions with this
27 // attribute.
28 //
29 // [1] https://developer.arm.com/documentation/ddi0616/aa
30 // [2] https://llvm.org/docs/AArch64SME.html
31 // [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
32 // [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
33 //
34 //===----------------------------------------------------------------------===//
35 
38 #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
39 
41 
42 #define DEBUG_TYPE "enable-arm-streaming"
43 
44 namespace mlir {
45 namespace arm_sme {
46 #define GEN_PASS_DEF_ENABLEARMSTREAMING
47 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
48 } // namespace arm_sme
49 } // namespace mlir
50 
51 using namespace mlir;
52 using namespace mlir::arm_sme;
53 namespace {
54 
55 constexpr StringLiteral
56  kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
57 
58 template <typename... Ops>
59 constexpr auto opList() {
60  return std::array{TypeID::get<Ops>()...};
61 }
62 
63 bool isScalableVector(Type type) {
64  if (auto vectorType = dyn_cast<VectorType>(type))
65  return vectorType.isScalable();
66  return false;
67 }
68 
69 struct EnableArmStreamingPass
70  : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
71  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72  bool ifRequiredByOps, bool ifScalableAndSupported) {
73  this->streamingMode = streamingMode;
74  this->zaMode = zaMode;
75  this->ifRequiredByOps = ifRequiredByOps;
76  this->ifScalableAndSupported = ifScalableAndSupported;
77  }
78  void runOnOperation() override {
79  auto function = getOperation();
80 
81  if (ifRequiredByOps && ifScalableAndSupported) {
82  function->emitOpError(
83  "enable-arm-streaming: `if-required-by-ops` and "
84  "`if-scalable-and-supported` are mutually exclusive");
85  return signalPassFailure();
86  }
87 
88  if (ifRequiredByOps) {
89  bool foundTileOp = false;
90  function.walk([&](Operation *op) {
91  if (llvm::isa<ArmSMETileOpInterface>(op)) {
92  foundTileOp = true;
93  return WalkResult::interrupt();
94  }
95  return WalkResult::advance();
96  });
97  if (!foundTileOp)
98  return;
99  }
100 
101  if (ifScalableAndSupported) {
102  // FIXME: This should be based on target information (i.e., the presence
103  // of FEAT_SME_FA64). This currently errs on the side of caution. If
104  // possible gathers/scatters should be lowered regular vector loads/stores
105  // before invoking this pass.
106  auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107  bool isCompatibleScalableFunction = false;
108  function.walk([&](Operation *op) {
109  if (llvm::is_contained(disallowedOperations,
110  op->getName().getTypeID())) {
111  isCompatibleScalableFunction = false;
112  return WalkResult::interrupt();
113  }
114  if (!isCompatibleScalableFunction &&
115  (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
116  llvm::any_of(op->getResultTypes(), isScalableVector))) {
117  isCompatibleScalableFunction = true;
118  }
119  return WalkResult::advance();
120  });
121  if (!isCompatibleScalableFunction)
122  return;
123  }
124 
125  if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126  streamingMode == ArmStreamingMode::Disabled)
127  return;
128 
129  auto unitAttr = UnitAttr::get(&getContext());
130 
131  function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
132 
133  // The pass currently only supports enabling ZA when in streaming-mode, but
134  // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
135  // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
136  // supporting this later.
137  if (zaMode != ArmZaMode::Disabled)
138  function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
139  }
140 };
141 } // namespace
142 
144  const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
145  bool ifRequiredByOps, bool ifScalableAndSupported) {
146  return std::make_unique<EnableArmStreamingPass>(
147  streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
148 }
static MLIRContext * getContext(OpFoldResult val)
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.