MLIR  22.0.0git
TileAllocation.cpp
Go to the documentation of this file.
1 //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transform allocates SME tiles at the 'func.func' op level for ArmSME
10 // operations. It roughly implements a linear scan register allocator, similar
11 // to the one outlined in [1], but with simplifications and assumptions made for
12 // our use case. Note that this is a greedy allocator (so it may not always find
13 // the most optimal allocation of tiles).
14 //
15 // The allocator operates at the CF dialect level. It is the responsibility of
16 // users to ensure the IR has been lowered to CF before invoking the tile
17 // allocator.
18 //
19 // The 128-bit tiles overlap with other element tiles as follows (see section
20 // B2.3.2 of SME spec [2]):
21 //
22 // Tile Overlaps
23 // ---------------------------------------------------------------------------
24 // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
25 // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
26 // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
27 // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
28 // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
29 // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
30 // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
31 // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
32 // ZA0.D ZA0.Q, ZA8.Q
33 // ZA1.D ZA1.Q, ZA9.Q
34 // ZA2.D ZA2.Q, ZA10.Q
35 // ZA3.D ZA3.Q, ZA11.Q
36 // ZA4.D ZA4.Q, ZA12.Q
37 // ZA5.D ZA5.Q, ZA13.Q
38 // ZA6.D ZA6.Q, ZA14.Q
39 // ZA7.D ZA7.Q, ZA15.Q
40 //
41 // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42 // Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43 // https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44 // [2] https://developer.arm.com/documentation/ddi0616/aa
45 //
46 //===----------------------------------------------------------------------===//
47 
#include "mlir/Analysis/Liveness.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
57 
58 namespace mlir::arm_sme {
59 #define GEN_PASS_DEF_TESTTILEALLOCATION
60 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
61 } // namespace mlir::arm_sme
62 
63 using namespace mlir;
64 using namespace mlir::arm_sme;
65 
66 namespace {
67 
68 enum class TileMask : unsigned {
69  // clang-format off
70  kZA0B = 0xffff, // 1111 1111 1111 1111
71 
72  kZA0H = 0xaaaa, // 1010 1010 1010 1010
73  kZA1H = 0x5555, // 0101 0101 0101 0101
74 
75  kZA0S = 0x8888, // 1000 1000 1000 1000
76  kZA1S = 0x4444, // 0100 0100 0100 0100
77  kZA2S = 0x2222, // 0010 0010 0010 0010
78  kZA3S = 0x1111, // 0001 0001 0001 0001
79 
80  kZA0D = 0x8080, // 1000 0000 1000 0000
81  kZA1D = 0x4040, // 0100 0000 0100 0000
82  kZA2D = 0x2020, // 0010 0000 0010 0000
83  kZA3D = 0x1010, // 0001 0000 0001 0000
84  kZA4D = 0x808, // 0000 1000 0000 1000
85  kZA5D = 0x404, // 0000 0100 0000 0100
86  kZA6D = 0x202, // 0000 0010 0000 0010
87  kZA7D = 0x101, // 0000 0001 0000 0001
88 
89  kZA0Q = 0x8000, // 1000 0000 0000 0000
90  kZA1Q = 0x4000, // 0100 0000 0000 0000
91  kZA2Q = 0x2000, // 0010 0000 0000 0000
92  kZA3Q = 0x1000, // 0001 0000 0000 0000
93  kZA4Q = 0x800, // 0000 1000 0000 0000
94  kZA5Q = 0x400, // 0000 0100 0000 0000
95  kZA6Q = 0x200, // 0000 0010 0000 0000
96  kZA7Q = 0x100, // 0000 0001 0000 0000
97  kZA8Q = 0x80, // 0000 0000 1000 0000
98  kZA9Q = 0x40, // 0000 0000 0100 0000
99  kZA10Q = 0x20, // 0000 0000 0010 0000
100  kZA11Q = 0x10, // 0000 0000 0001 0000
101  kZA12Q = 0x8, // 0000 0000 0000 1000
102  kZA13Q = 0x4, // 0000 0000 0000 0100
103  kZA14Q = 0x2, // 0000 0000 0000 0010
104  kZA15Q = 0x1, // 0000 0000 0000 0001
105 
106  kNone = 0x0, // 0000 0000 0000 0000
107  // clang-format on
108 
110 };
111 
112 /// Returns the set of masks relevant for the given type.
113 static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
114  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
115  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
116  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117  TileMask::kZA2S, TileMask::kZA3S};
118  static constexpr std::array ZA_D_MASKS = {
119  TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120  TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
121  static constexpr std::array ZA_Q_MASKS = {
122  TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123  TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124  TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125  TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
126  switch (type) {
127  case ArmSMETileType::ZAB:
128  return ZA_B_MASKS;
129  case ArmSMETileType::ZAH:
130  return ZA_H_MASKS;
131  case ArmSMETileType::ZAS:
132  return ZA_S_MASKS;
133  case ArmSMETileType::ZAD:
134  return ZA_D_MASKS;
135  case ArmSMETileType::ZAQ:
136  return ZA_Q_MASKS;
137  }
138  llvm_unreachable("unknown type in getMasks");
139 }
140 
141 class TileAllocator {
142 public:
143  /// Allocates and returns a tile ID. Fails if there are no tiles left.
144  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
145  auto masks = getMasks(tileType);
146  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
147  if ((tilesInUse & tileMask) == TileMask::kNone) {
148  tilesInUse |= tileMask;
149  return tileId;
150  }
151  }
152  return failure();
153  }
154 
155  /// Acquires a specific tile ID. Asserts the tile is initially free.
156  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
157  TileMask tileMask = getMasks(tileType)[tileId];
158  assert((tilesInUse & tileMask) == TileMask::kNone &&
159  "cannot acquire allocated tile!");
160  tilesInUse |= tileMask;
161  }
162 
163  /// Releases a previously allocated tile ID.
164  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
165  TileMask tileMask = getMasks(tileType)[tileId];
166  assert((tilesInUse & tileMask) == tileMask &&
167  "cannot release unallocated tile!");
168  tilesInUse ^= tileMask;
169  }
170 
171  /// Allocates an in-memory tile ID.
172  unsigned allocateInMemoryTileId() {
173  // Note: We never release in-memory tile IDs. We could, which may allow
174  // reusing an allocation, but as we _never_ want to spill an SME tile this
175  // is not optimized.
176  return nextInMemoryTileId++;
177  }
178 
179 private:
180  TileMask tilesInUse = TileMask::kNone;
181  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
182 };
183 
184 /// Add new intermediate blocks for the true and false destinations of
185 /// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
186 /// overlaps due to copies at branches.
187 ///
188 /// BEFORE:
189 /// ```mlir
190 /// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
191 /// ```
192 ///
193 /// AFTER:
194 /// ```mlir
195 /// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
196 /// ^bb1_copy:
197 /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
198 /// ^bb2_copy:
199 /// cf.br ^bb2
200 /// ```
201 void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
203  function.walk([&](cf::CondBranchOp condBranch) {
204  if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
205  return isValidSMETileVectorType(value.getType());
206  })) {
207  worklist.push_back(condBranch);
208  }
209  });
210 
211  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
212  rewriter.setInsertionPointToEnd(source);
213  cf::BranchOp::create(rewriter, loc, dest, args);
214  };
215 
216  for (auto condBranch : worklist) {
217  auto loc = condBranch.getLoc();
218  Block *block = condBranch->getBlock();
219  auto newTrueBranch = rewriter.splitBlock(block, block->end());
220  auto newFalseBranch = rewriter.splitBlock(block, block->end());
221  insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
222  condBranch.getTrueDestOperands());
223  insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
224  condBranch.getFalseDestOperands());
225  rewriter.modifyOpInPlace(condBranch, [&] {
226  condBranch.getFalseDestOperandsMutable().clear();
227  condBranch.getTrueDestOperandsMutable().clear();
228  condBranch.setSuccessor(newTrueBranch, 0);
229  condBranch.setSuccessor(newFalseBranch, 1);
230  });
231  }
232 }
233 
234 /// Inserts tile copies at `cf.br` operations.
235 ///
236 /// BEFORE:
237 /// ```mlir
238 /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
239 /// ```
240 ///
241 /// AFTER:
242 /// ```mlir
243 /// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
244 /// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
245 /// ```
246 void insertCopiesAtBranches(IRRewriter &rewriter,
247  FunctionOpInterface function) {
248  for (Block &block : function.getBlocks()) {
249  Operation *terminator = block.getTerminator();
250  if (!isa<cf::BranchOp>(terminator))
251  continue;
252  rewriter.setInsertionPoint(terminator);
253  for (OpOperand &operand : terminator->getOpOperands()) {
254  if (isValidSMETileVectorType(operand.get().getType())) {
255  auto copy =
256  CopyTileOp::create(rewriter, terminator->getLoc(), operand.get());
257  rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
258  }
259  }
260  }
261 }
262 
263 /// Prepares the IR for tile allocation. It does this by first 'splitting'
264 /// conditional branches (see `splitCondBranches`), then inserting tile copies
265 /// at branch operations. The conditional branches are split to prevent the
266 /// copies needed for them overlapping between the true and false paths of the
267 /// branch (see `tile-allocation-copies.mlir` and
268 /// `tile-allocation-liveness.mlir` for examples). The copies break up live
269 /// ranges and ensure when moving out of SSA the semantics of the program are
270 /// preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  // Order matters: conditional branches are split first so the copies inserted
  // below for the true/false destinations land in separate blocks (and their
  // live ranges do not artificially overlap).
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
276 
277 /// A live range for a (collection of) tile values. A live range is built up of
278 /// non-overlapping intervals [start, end) which represent parts of the program
279 /// where a value in the range needs to be live (i.e. in an SME virtual tile).
280 /// Note that as the intervals are non-overlapping all values within a live
281 /// range can be allocated to the same SME virtual tile.
struct LiveRange {
  // Half-open intervals keyed on operation numbers; the uint8_t mapped value
  // is a dummy (see kValidLiveRange below).
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  // `ranges` is held via unique_ptr so LiveRange stays movable (needed for
  // storage in a DenseMap); llvm::IntervalMap itself is not movable.
  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    // An empty interval ([x, x)) would be malformed for a half-open map, so
    // only the value is recorded in that case.
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  // Orders live ranges by their starting point (used to sort for allocation).
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  // Note: requires `values` to be non-empty. All values in a live range share
  // a tile type (they are allocated to the same tile), so the first suffices.
  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
340 
/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct if
/// all ArmSME ops have been converted to CF (which is asserted).
346 generateOperationNumbering(FunctionOpInterface function) {
347  unsigned index = 0;
348  SetVector<Block *> blocks =
349  getBlocksSortedByDominance(function.getFunctionBody());
350  DenseMap<Operation *, unsigned> operationToIndexMap;
351  for (Block *block : blocks) {
352  index++; // We want block args to have their own number.
353  for (Operation &op : block->getOperations()) {
354 #ifndef NDEBUG
355  op.walk([&](ArmSMETileOpInterface nestedOp) {
356  assert(&op == nestedOp.getOperation() &&
357  "ArmSME tile allocation does not support nested regions");
358  });
359 #endif
360  operationToIndexMap.try_emplace(&op, index++);
361  }
362  }
363  return operationToIndexMap;
364 }
365 
366 /// Gather live ranges for SME tiles from the MLIR liveness analysis.
368 gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
369  LiveRange::Allocator &liveRangeAllocator,
370  Liveness &liveness, FunctionOpInterface function) {
371  assert(!operationToIndexMap.empty() && "expected operation numbering");
372  DenseMap<Value, LiveRange> liveRanges;
373  /// Defines or updates a live range for an SME tile value. Live-ins may update
374  /// an existing live range (rather than define a new one). Note: If
375  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
376  /// the block.
377  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
378  LivenessBlockInfo const &livenessInfo,
379  bool liveAtBlockEntry = false) {
380  if (!isValidSMETileVectorType(value.getType()))
381  return;
382  // Find or create a live range for `value`.
383  auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
384  LiveRange &valueLiveRange = it->second;
385  auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
386  // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
387  unsigned startOpIdx =
388  operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
389  unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
390  valueLiveRange.insert(value, startOpIdx, endOpIdx);
391  };
392 
393  for (Block &block : function.getBlocks()) {
394  LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
395  // Handle block arguments:
396  for (Value argument : block.getArguments())
397  defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
398  /*liveAtBlockEntry=*/true);
399  // Handle live-ins:
400  for (Value liveIn : livenessInfo->in())
401  defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
402  /*liveAtBlockEntry=*/true);
403  // Handle new definitions:
404  for (Operation &op : block) {
405  for (Value result : op.getResults())
406  defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
407  }
408  }
409 
410  return liveRanges;
411 }
412 
413 /// Iterate over all predecessor tile values to a (tile) block argument.
414 static void forEachPredecessorTileValue(BlockArgument blockArg,
415  function_ref<void(Value)> callback) {
416  Block *block = blockArg.getOwner();
417  unsigned argNumber = blockArg.getArgNumber();
418  for (Block *pred : block->getPredecessors()) {
419  TypeSwitch<Operation *>(pred->getTerminator())
420  .Case<cf::BranchOp>([&](auto branch) {
421  Value predecessorOperand = branch.getDestOperands()[argNumber];
422  callback(predecessorOperand);
423  })
424  .Case<cf::CondBranchOp>([&](auto condBranch) {
425  if (condBranch.getFalseDest() == block) {
426  Value predecessorOperand =
427  condBranch.getFalseDestOperands()[argNumber];
428  callback(predecessorOperand);
429  }
430  if (condBranch.getTrueDest() == block) {
431  Value predecessorOperand =
432  condBranch.getTrueDestOperands()[argNumber];
433  callback(predecessorOperand);
434  }
435  });
436  }
437 }
438 
439 /// Coalesce live ranges where it would prevent unnecessary tile moves.
441 coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
442  DenseMap<Value, LiveRange *> liveRanges;
443  for (auto &[value, liveRange] : initialLiveRanges) {
444  liveRanges.insert({value, &liveRange});
445  }
446 
447  // Merge the live ranges of values `a` and `b` into one (if they do not
448  // overlap). After this, the values `a` and `b` will both point to the same
449  // live range (which will contain multiple values).
450  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
451  LiveRange *aLiveRange = liveRanges.at(a);
452  LiveRange *bLiveRange = liveRanges.at(b);
453  if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
454  aLiveRange->unionWith(*bLiveRange);
455  for (Value value : bLiveRange->values)
456  liveRanges[value] = aLiveRange;
457  }
458  };
459 
460  // Merge the live ranges of new definitions with their tile operands.
461  auto unifyDefinitionsWithOperands = [&](Value value) {
462  auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
463  if (!armSMEOp)
464  return;
465  for (auto operand : armSMEOp->getOperands()) {
466  if (isValidSMETileVectorType(operand.getType()))
467  mergeValuesIfNonOverlapping(value, operand);
468  }
469  };
470 
471  // Merge the live ranges of block arguments with their predecessors.
472  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
473  auto blockArg = dyn_cast<BlockArgument>(value);
474  if (!blockArg)
475  return;
476  forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
477  mergeValuesIfNonOverlapping(blockArg, predecessorTile);
478  });
479  };
480 
481  auto applyRule = [&](auto rule) {
482  llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
483  };
484 
485  // Unify as many live ranges as we can. This prevents unnecessary moves.
486  applyRule(unifyBlockArgumentsWithPredecessors);
487  applyRule(unifyDefinitionsWithOperands);
488 
489  // Remove duplicate live range entries.
490  SetVector<LiveRange *> uniqueLiveRanges;
491  for (auto [_, liveRange] : liveRanges) {
492  if (!liveRange->empty())
493  uniqueLiveRanges.insert(liveRange);
494  }
495 
496  // Sort the new live ranges by starting point (ready for tile allocation).
497  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
498  llvm::sort(coalescedLiveRanges,
499  [](LiveRange *a, LiveRange *b) { return *a < *b; });
500  return std::move(coalescedLiveRanges);
501 }
502 
503 /// Choose a live range to spill (via some heuristics). This picks either a live
504 /// range from `overlappingRanges`, or the new live range `newRange`.
505 template <typename OverlappingRangesIterator>
506 LiveRange *
507 chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
508  LiveRange *newRange) {
509  // Heuristic: Spill trivially copyable operations (usually free).
510  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
511  return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
512  newRange->getTileType()) &&
513  allocatedRange.values.size() == 1 &&
515  allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
516  };
517  if (isTrivialSpill(*newRange))
518  return newRange;
519  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
520  if (trivialSpill != overlappingRanges.end())
521  return &*trivialSpill;
522 
523  // Heuristic: Spill the range that ends last (with a compatible tile type).
524  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
525  return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
526  a.end() < b.end();
527  };
528  LiveRange &latestEndingLiveRange =
529  *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
530  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
531  return &latestEndingLiveRange;
532  return newRange;
533 }
534 
535 /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive. Their tile ID is
      // released (so it can be reused within the hole) but remembered on the
      // range, ready to be re-acquired when it becomes active again.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active. Their previously
      // assigned tile ID is re-acquired (it was released on deactivation).
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      // The spilled range (possibly `nextRange` itself) lives in memory from
      // here on; IDs >= kInMemoryTileIdBase denote in-memory tiles.
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
638 
639 /// Assigns a tile ID to an MLIR value.
640 void assignTileIdToValue(IRRewriter &rewriter, Value value,
641  IntegerAttr tileIdAttr) {
642  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
643  rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
644  for (Operation *user : value.getUsers()) {
645  if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
646  // Ensure ArmSME ops that don't produce a value still get a tile ID.
647  if (!hasTileResult(tileOp))
648  rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
649  }
650  }
651 }
652 
653 /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
654 LogicalResult assignTileIdsAndResolveTrivialConflicts(
655  IRRewriter &rewriter, FunctionOpInterface function,
656  ArrayRef<LiveRange *> allocatedLiveRanges) {
657  for (LiveRange const *liveRange : allocatedLiveRanges) {
658  auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
659  auto isAllocatedToSameTile = [&](Value value) {
660  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
661  tileOp && tileOp.getTileId() == tileIdAttr)
662  return true;
663  return liveRange->values.contains(value);
664  };
665 
666  /// Eliminates copies where the operand has the same tile ID.
667  auto foldRedundantCopies = [&](Value value) -> LogicalResult {
668  auto copyOp = value.getDefiningOp<CopyTileOp>();
669  if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
670  return failure();
671  rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
672  return success();
673  };
674 
675  /// Validates each predecessor to a tile block argument has been assigned
676  /// the same tile ID.
677  auto validateBlockArguments = [&](Value value) {
678  auto blockArg = dyn_cast<BlockArgument>(value);
679  if (!blockArg) {
680  // Not a block argument (nothing to validate).
681  return success();
682  }
683  bool tileMismatch = false;
684  forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
685  if (tileMismatch)
686  return;
687  if (!isAllocatedToSameTile(predecessorTile)) {
688  blockArg.getOwner()->getParentOp()->emitOpError(
689  "block argument not allocated to the same SME virtial tile as "
690  "predecessors");
691  tileMismatch = true;
692  }
693  });
694  return success(/*isSuccess=*/!tileMismatch);
695  };
696 
697  /// Attempts to resolve (trivial) tile ID conflicts.
698  auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
699  auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
700  OpOperand *tileOperand = getTileOpOperand(tileOp);
701  if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
702  // Operand already allocated to the correct tile.
703  // No conflict to resolve.
704  return success();
705  }
706  auto operandTileOp =
707  tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
708  if (!isTriviallyCloneableTileOp(operandTileOp)) {
709  auto error =
710  tileOp.emitOpError("tile operand allocated to different SME "
711  "virtial tile (move required)");
712  error.attachNote(tileOperand->get().getLoc())
713  << "tile operand is: " << tileOperand->get();
714  return error;
715  }
716  // Cloning prevents a move/spill (though may require recomputation).
717  rewriter.setInsertionPoint(tileOp);
718  auto clonedOp = operandTileOp.clone();
719  rewriter.modifyOpInPlace(clonedOp,
720  [&] { clonedOp.setTileId(tileOp.getTileId()); });
721  rewriter.insert(clonedOp);
722  if (isa<CopyTileOp>(tileOp)) {
723  rewriter.replaceAllUsesWith(tileOp->getResult(0),
724  clonedOp->getResult(0));
725  } else {
726  rewriter.modifyOpInPlace(
727  tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
728  }
729  return success();
730  };
731 
732  for (Value value : liveRange->values) {
733  // 1. Assign the tile ID to the value.
734  assignTileIdToValue(rewriter, value, tileIdAttr);
735 
736  // 2. Attempt to eliminate redundant tile copies.
737  if (succeeded(foldRedundantCopies(value)))
738  continue;
739 
740  // 3. Validate tile block arguments.
741  if (failed(validateBlockArguments(value)))
742  return failure();
743 
744  // 4. Attempt to resolve (trivial) tile ID conflicts.
745  if (failed(resolveTrivialTileConflicts(value)))
746  return failure();
747  }
748  }
749  return success();
750 }
751 
752 /// Prints live ranges alongside operation names for debugging.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      // One column per live range: print S/E/| at this operation's position.
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          // If an interval both ends and starts here (from two intervals in
          // the same range), it is shown as continuously live ('|').
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}
780 
/// Test pass wrapping `allocateSMETiles` (pass options are declared in the
/// generated TestTileAllocationBase).
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    // With `preprocessOnly` set, only run the IR preprocessing (conditional
    // branch splitting + tile copy insertion), not the allocator itself.
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
794 } // namespace
795 
/// Allocates SME tiles for all ArmSME ops in `function` (which must already be
/// lowered to CF). When `dumpRanges` is set, the computed live ranges are
/// printed to stderr (before and after coalescing) for debugging.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    llvm::sort(initialRanges,
               [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
  // users). This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:240
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
Definition: Liveness.cpp:225
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:416
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:43
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition: Utils.cpp:118
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition: Utils.cpp:138
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is greater than or equal to (in terms of bytes) typeB.
Definition: Utils.cpp:166
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:28
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition: Utils.cpp:151
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition: Utils.cpp:143
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:83
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304