MLIR  19.0.0git
TileAllocation.cpp
Go to the documentation of this file.
1 //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass allocates SME tiles at the 'func.func' op level for ArmSME
10 // operations. It does this using a 16-bit tile mask that has a bit for each
11 // 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
12 //
13 // The 128-bit tiles overlap with other element tiles as follows (see section
14 // B2.3.2 of SME spec [1]):
15 //
16 // Tile Overlaps
17 // ---------------------------------------------------------------------------
18 // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
19 // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
20 // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
21 // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
22 // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
23 // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
24 // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
25 // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
26 // ZA0.D ZA0.Q, ZA8.Q
27 // ZA1.D ZA1.Q, ZA9.Q
28 // ZA2.D ZA2.Q, ZA10.Q
29 // ZA3.D ZA3.Q, ZA11.Q
30 // ZA4.D ZA4.Q, ZA12.Q
31 // ZA5.D ZA5.Q, ZA13.Q
32 // ZA6.D ZA6.Q, ZA14.Q
33 // ZA7.D ZA7.Q, ZA15.Q
34 //
35 // The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
36 // that is initialized during the first tile allocation within a function and
37 // updated on each subsequent allocation.
38 //
39 // [1] https://developer.arm.com/documentation/ddi0616/aa
40 //
41 //===----------------------------------------------------------------------===//
42 
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
49 
50 #define DEBUG_TYPE "allocate-arm-sme-tiles"
51 
52 namespace mlir {
53 namespace arm_sme {
54 #define GEN_PASS_DEF_TILEALLOCATION
55 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
56 } // namespace arm_sme
57 } // namespace mlir
58 
59 using namespace mlir;
60 using namespace mlir::arm_sme;
61 
62 namespace {
63 
64 static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
65 static constexpr StringLiteral
66  kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
67 
68 enum class TileMask : unsigned {
69  // clang-format off
70  kZA0B = 0xffff, // 1111 1111 1111 1111
71 
72  kZA0H = 0xaaaa, // 1010 1010 1010 1010
73  kZA1H = 0x5555, // 0101 0101 0101 0101
74 
75  kZA0S = 0x8888, // 1000 1000 1000 1000
76  kZA1S = 0x4444, // 0100 0100 0100 0100
77  kZA2S = 0x2222, // 0010 0010 0010 0010
78  kZA3S = 0x1111, // 0001 0001 0001 0001
79 
80  kZA0D = 0x8080, // 1000 0000 1000 0000
81  kZA1D = 0x4040, // 0100 0000 0100 0000
82  kZA2D = 0x2020, // 0010 0000 0010 0000
83  kZA3D = 0x1010, // 0001 0000 0001 0000
84  kZA4D = 0x808, // 0000 1000 0000 1000
85  kZA5D = 0x404, // 0000 0100 0000 0100
86  kZA6D = 0x202, // 0000 0010 0000 0010
87  kZA7D = 0x101, // 0000 0001 0000 0001
88 
89  kZA0Q = 0x8000, // 1000 0000 0000 0000
90  kZA1Q = 0x4000, // 0100 0000 0000 0000
91  kZA2Q = 0x2000, // 0010 0000 0000 0000
92  kZA3Q = 0x1000, // 0001 0000 0000 0000
93  kZA4Q = 0x800, // 0000 1000 0000 0000
94  kZA5Q = 0x400, // 0000 0100 0000 0000
95  kZA6Q = 0x200, // 0000 0010 0000 0000
96  kZA7Q = 0x100, // 0000 0001 0000 0000
97  kZA8Q = 0x80, // 0000 0000 1000 0000
98  kZA9Q = 0x40, // 0000 0000 0100 0000
99  kZA10Q = 0x20, // 0000 0000 0010 0000
100  kZA11Q = 0x10, // 0000 0000 0001 0000
101  kZA12Q = 0x8, // 0000 0000 0000 1000
102  kZA13Q = 0x4, // 0000 0000 0000 0100
103  kZA14Q = 0x2, // 0000 0000 0000 0010
104  kZA15Q = 0x1, // 0000 0000 0000 0001
105 
106  kNone = 0x0, // 0000 0000 0000 0000
107  // clang-format on
108 
109  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
110 };
111 
112 /// Returns the set of masks relevant for the given type.
113 static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
114  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
115  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
116  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117  TileMask::kZA2S, TileMask::kZA3S};
118  static constexpr std::array ZA_D_MASKS = {
119  TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120  TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
121  static constexpr std::array ZA_Q_MASKS = {
122  TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123  TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124  TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125  TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
126  switch (type) {
127  case ArmSMETileType::ZAB:
128  return ZA_B_MASKS;
129  case ArmSMETileType::ZAH:
130  return ZA_H_MASKS;
131  case ArmSMETileType::ZAS:
132  return ZA_S_MASKS;
133  case ArmSMETileType::ZAD:
134  return ZA_D_MASKS;
135  case ArmSMETileType::ZAQ:
136  return ZA_Q_MASKS;
137  }
138 }
139 
140 /// Allocates and returns a tile ID. Returns an error if there are no tiles
141 /// left.
142 static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
143  TileMask &tilesInUse) {
144  auto masks = getMasks(tileType);
145  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
146  if ((tilesInUse & tileMask) == TileMask::kNone) {
147  tilesInUse |= tileMask;
148  return tileId;
149  }
150  }
151  return failure();
152 }
153 
154 /// Collects transitive uses of a root value through control flow. This can
155 /// handle basic SCF constructs, along with control flow (br and cond_br).
156 /// Simple loops work at the SCF level, while more complex control flow can be
157 /// dealt with after lowering to CF. This is used to implement basic tile
158 /// allocation.
159 static void findDependantOps(Value rootValue,
160  SetVector<Operation *> &dependantOps) {
161  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
162  for (auto [idx, value] : llvm::enumerate(inputValues)) {
163  if (value == rootValue)
164  findDependantOps(exitValues[idx], dependantOps);
165  }
166  };
167  for (Operation *user : rootValue.getUsers()) {
168  if (dependantOps.contains(user))
169  continue;
170  dependantOps.insert(user);
172  .Case<cf::BranchOp>([&](auto branchOp) {
173  // (CF) Follow branch.
174  traverseCorrespondingValues(branchOp.getDestOperands(),
175  branchOp.getDest()->getArguments());
176  })
177  .Case<cf::CondBranchOp>([&](auto condBranchOp) {
178  // (CF) Follow true branch.
179  traverseCorrespondingValues(
180  condBranchOp.getTrueOperands(),
181  condBranchOp.getTrueDest()->getArguments());
182  // (CF) Follow false branch.
183  traverseCorrespondingValues(
184  condBranchOp.getFalseOperands(),
185  condBranchOp.getFalseDest()->getArguments());
186  })
187  .Case<LoopLikeOpInterface>([&](auto loopOp) {
188  // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
189  traverseCorrespondingValues(loopOp.getInits(),
190  loopOp.getRegionIterArgs());
191  })
192  .Case<scf::YieldOp>([&](auto yieldOp) {
193  // (SCF) Follow yields of (basic) control flow (e.g. for loops).
194  auto parent = user->getParentOp();
195  traverseCorrespondingValues(user->getOperands(),
196  parent->getResults());
197  })
198  .Default([&](auto) {
199  // Otherwise, assume users of _any_ result are dependant.
200  for (Value result : user->getResults())
201  findDependantOps(result, dependantOps);
202  });
203  }
204 }
205 struct AssignTileIDsPattern
206  : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
208  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
209  PatternRewriter &rewriter) const override {
210  if (tileOp.getTileId())
211  return failure();
212 
213  auto func = tileOp->getParentOfType<FunctionOpInterface>();
214  auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
215  if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
216  func->getDiscardableAttr(name)))
217  return unsigned(attr.getInt());
218  return defaultVal;
219  };
220  auto setDiscardableIntAttr = [&](StringRef name, auto value) {
221  rewriter.modifyOpInPlace(tileOp, [&] {
222  func->setDiscardableAttr(name,
223  rewriter.getI32IntegerAttr((unsigned)value));
224  });
225  };
226 
227  std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
228  if (!tileType)
229  return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
230 
231  TileMask tilesInUse =
232  static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
233  auto tileId = allocateTileId(*tileType, tilesInUse);
234  bool tileIsInMemory = failed(tileId);
235  if (tileIsInMemory) {
236  // If we could not find a real tile ID, use an in-memory tile ID (ID >=
237  // 16). A later pass will insert the necessary spills and reloads.
238  tileId =
239  getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
240  tileOp->emitWarning(
241  "failed to allocate SME virtual tile to operation, all tile "
242  "operations will go through memory, expect degraded performance");
243  }
244 
245  // Set all operations dependent on `tileOp` to use the same tile ID.
246  // This is a naive tile allocation scheme, but works for common cases. For
247  // example, as this only allocates tile IDs to existing ops, it can't solve
248  // cases like this (%tileA and %tileB come from different root operations):
249  //
250  // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
251  // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
252  // } else {
253  // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
254  // }
255  //
256  // This case would require allocating a new tile for the result of the
257  // scf.if, and moving the contents of %tileA or %tileB to result tile (based
258  // on the %some_cond).
259  // Find all the ops that (transitively) depend on this tile.
260  SetVector<Operation *> dependantOps;
261  findDependantOps(tileOp->getResult(0), dependantOps);
262  auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
263  for (auto *op : dependantOps) {
264  if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
265  auto currentTileId = dependantTileOp.getTileId();
266  if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
267  return dependantTileOp.emitOpError(
268  "already assigned different SME virtual tile!");
269  }
270  }
271 
272  // Rewrite IR.
273  if (!tileIsInMemory)
274  setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
275  else
276  setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
277  rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
278  for (auto *op : dependantOps) {
279  if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
280  rewriter.modifyOpInPlace(
281  dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
282  }
283  }
284 
285  return success();
286  }
287 };
288 
289 struct TileAllocationPass
290  : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
291  void runOnOperation() override {
292  RewritePatternSet patterns(&getContext());
293  patterns.add<AssignTileIDsPattern>(patterns.getContext());
294  GreedyRewriteConfig config;
295  // Setting useTopDownTraversal ensures tiles are allocated in program
296  // order.
297  config.useTopDownTraversal = true;
299  getOperation(), std::move(patterns), config))) {
300  signalPassFailure();
301  }
302  }
303 };
304 } // namespace
305 
307  return std::make_unique<TileAllocationPass>();
308 }
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
static constexpr unsigned kInMemoryTileIdBase
Definition: ArmSME.h:28
std::unique_ptr< Pass > createTileAllocationPass()
Pass that allocates tile IDs to ArmSME operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374