MLIR  20.0.0git
TileAllocation.cpp
Go to the documentation of this file.
1 //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transform allocates SME tiles at the 'func.func' op level for ArmSME
10 // operations. It roughly implements a linear scan register allocator, similar
11 // to the one outlined in [1], but with simplifications and assumptions made for
12 // our use case. Note that this is a greedy allocator (so it may not always find
13 // the most optimal allocation of tiles).
14 //
15 // The allocator operates at the CF dialect level. It is the responsibility of
16 // users to ensure the IR has been lowered to CF before invoking the tile
17 // allocator.
18 //
19 // The 128-bit tiles overlap with other element tiles as follows (see section
20 // B2.3.2 of SME spec [2]):
21 //
22 // Tile Overlaps
23 // ---------------------------------------------------------------------------
24 // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
25 // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
26 // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
27 // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
28 // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
29 // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
30 // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
31 // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
32 // ZA0.D ZA0.Q, ZA8.Q
33 // ZA1.D ZA1.Q, ZA9.Q
34 // ZA2.D ZA2.Q, ZA10.Q
35 // ZA3.D ZA3.Q, ZA11.Q
36 // ZA4.D ZA4.Q, ZA12.Q
37 // ZA5.D ZA5.Q, ZA13.Q
38 // ZA6.D ZA6.Q, ZA14.Q
39 // ZA7.D ZA7.Q, ZA15.Q
40 //
41 // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42 // Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43 // https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44 // [2] https://developer.arm.com/documentation/ddi0616/aa
45 //
46 //===----------------------------------------------------------------------===//
47 
48 #include "mlir/Analysis/Liveness.h"
56 #include "llvm/ADT/IntervalMap.h"
57 #include "llvm/ADT/TypeSwitch.h"
58 #include <algorithm>
59 
60 namespace mlir::arm_sme {
61 #define GEN_PASS_DEF_TESTTILEALLOCATION
62 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
63 } // namespace mlir::arm_sme
64 
65 using namespace mlir;
66 using namespace mlir::arm_sme;
67 
68 namespace {
69 
70 enum class TileMask : unsigned {
71  // clang-format off
72  kZA0B = 0xffff, // 1111 1111 1111 1111
73 
74  kZA0H = 0xaaaa, // 1010 1010 1010 1010
75  kZA1H = 0x5555, // 0101 0101 0101 0101
76 
77  kZA0S = 0x8888, // 1000 1000 1000 1000
78  kZA1S = 0x4444, // 0100 0100 0100 0100
79  kZA2S = 0x2222, // 0010 0010 0010 0010
80  kZA3S = 0x1111, // 0001 0001 0001 0001
81 
82  kZA0D = 0x8080, // 1000 0000 1000 0000
83  kZA1D = 0x4040, // 0100 0000 0100 0000
84  kZA2D = 0x2020, // 0010 0000 0010 0000
85  kZA3D = 0x1010, // 0001 0000 0001 0000
86  kZA4D = 0x808, // 0000 1000 0000 1000
87  kZA5D = 0x404, // 0000 0100 0000 0100
88  kZA6D = 0x202, // 0000 0010 0000 0010
89  kZA7D = 0x101, // 0000 0001 0000 0001
90 
91  kZA0Q = 0x8000, // 1000 0000 0000 0000
92  kZA1Q = 0x4000, // 0100 0000 0000 0000
93  kZA2Q = 0x2000, // 0010 0000 0000 0000
94  kZA3Q = 0x1000, // 0001 0000 0000 0000
95  kZA4Q = 0x800, // 0000 1000 0000 0000
96  kZA5Q = 0x400, // 0000 0100 0000 0000
97  kZA6Q = 0x200, // 0000 0010 0000 0000
98  kZA7Q = 0x100, // 0000 0001 0000 0000
99  kZA8Q = 0x80, // 0000 0000 1000 0000
100  kZA9Q = 0x40, // 0000 0000 0100 0000
101  kZA10Q = 0x20, // 0000 0000 0010 0000
102  kZA11Q = 0x10, // 0000 0000 0001 0000
103  kZA12Q = 0x8, // 0000 0000 0000 1000
104  kZA13Q = 0x4, // 0000 0000 0000 0100
105  kZA14Q = 0x2, // 0000 0000 0000 0010
106  kZA15Q = 0x1, // 0000 0000 0000 0001
107 
108  kNone = 0x0, // 0000 0000 0000 0000
109  // clang-format on
110 
112 };
113 
114 /// Returns the set of masks relevant for the given type.
115 static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
116  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
117  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
118  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
119  TileMask::kZA2S, TileMask::kZA3S};
120  static constexpr std::array ZA_D_MASKS = {
121  TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
122  TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
123  static constexpr std::array ZA_Q_MASKS = {
124  TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
125  TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
126  TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
127  TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
128  switch (type) {
129  case ArmSMETileType::ZAB:
130  return ZA_B_MASKS;
131  case ArmSMETileType::ZAH:
132  return ZA_H_MASKS;
133  case ArmSMETileType::ZAS:
134  return ZA_S_MASKS;
135  case ArmSMETileType::ZAD:
136  return ZA_D_MASKS;
137  case ArmSMETileType::ZAQ:
138  return ZA_Q_MASKS;
139  }
140  llvm_unreachable("unknown type in getMasks");
141 }
142 
143 class TileAllocator {
144 public:
145  /// Allocates and returns a tile ID. Fails if there are no tiles left.
146  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
147  auto masks = getMasks(tileType);
148  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
149  if ((tilesInUse & tileMask) == TileMask::kNone) {
150  tilesInUse |= tileMask;
151  return tileId;
152  }
153  }
154  return failure();
155  }
156 
157  /// Acquires a specific tile ID. Asserts the tile is initially free.
158  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
159  TileMask tileMask = getMasks(tileType)[tileId];
160  assert((tilesInUse & tileMask) == TileMask::kNone &&
161  "cannot acquire allocated tile!");
162  tilesInUse |= tileMask;
163  }
164 
165  /// Releases a previously allocated tile ID.
166  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
167  TileMask tileMask = getMasks(tileType)[tileId];
168  assert((tilesInUse & tileMask) == tileMask &&
169  "cannot release unallocated tile!");
170  tilesInUse ^= tileMask;
171  }
172 
173  /// Allocates an in-memory tile ID.
174  unsigned allocateInMemoryTileId() {
175  // Note: We never release in-memory tile IDs. We could, which may allow
176  // reusing an allocation, but as we _never_ want to spill an SME tile this
177  // is not optimized.
178  return nextInMemoryTileId++;
179  }
180 
181 private:
182  TileMask tilesInUse = TileMask::kNone;
183  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
184 };
185 
186 /// Add new intermediate blocks for the true and false destinations of
187 /// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
188 /// overlaps due to copies at branches.
189 ///
190 /// BEFORE:
191 /// ```mlir
192 /// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
193 /// ```
194 ///
195 /// AFTER:
196 /// ```mlir
197 /// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
198 /// ^bb1_copy:
199 /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
200 /// ^bb2_copy:
201 /// cf.br ^bb2
202 /// ```
203 void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
205  function.walk([&](cf::CondBranchOp condBranch) {
206  if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
207  return isValidSMETileVectorType(value.getType());
208  })) {
209  worklist.push_back(condBranch);
210  }
211  });
212 
213  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
214  rewriter.setInsertionPointToEnd(source);
215  rewriter.create<cf::BranchOp>(loc, dest, args);
216  };
217 
218  for (auto condBranch : worklist) {
219  auto loc = condBranch.getLoc();
220  Block *block = condBranch->getBlock();
221  auto newTrueBranch = rewriter.splitBlock(block, block->end());
222  auto newFalseBranch = rewriter.splitBlock(block, block->end());
223  insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
224  condBranch.getTrueDestOperands());
225  insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
226  condBranch.getFalseDestOperands());
227  rewriter.modifyOpInPlace(condBranch, [&] {
228  condBranch.getFalseDestOperandsMutable().clear();
229  condBranch.getTrueDestOperandsMutable().clear();
230  condBranch.setSuccessor(newTrueBranch, 0);
231  condBranch.setSuccessor(newFalseBranch, 1);
232  });
233  }
234 }
235 
236 /// Inserts tile copies at `cf.br` operations.
237 ///
238 /// BEFORE:
239 /// ```mlir
240 /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
241 /// ```
242 ///
243 /// AFTER:
244 /// ```mlir
245 /// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
246 /// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
247 /// ```
248 void insertCopiesAtBranches(IRRewriter &rewriter,
249  FunctionOpInterface function) {
250  for (Block &block : function.getBlocks()) {
251  Operation *terminator = block.getTerminator();
252  if (!isa<cf::BranchOp>(terminator))
253  continue;
254  rewriter.setInsertionPoint(terminator);
255  for (OpOperand &operand : terminator->getOpOperands()) {
256  if (isValidSMETileVectorType(operand.get().getType())) {
257  auto copy =
258  rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
259  rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
260  }
261  }
262  }
263 }
264 
265 /// Prepares the IR for tile allocation. It does this by first 'splitting'
266 /// conditional branches (see `splitCondBranches`), then inserting tile copies
267 /// at branch operations. The conditional branches are split to prevent the
268 /// copies needed for them overlapping between the true and false paths of the
269 /// branch (see `tile-allocation-copies.mlir` and
270 /// `tile-allocation-liveness.mlir` for examples). The copies break up live
271 /// ranges and ensure when moving out of SSA the semantics of the program are
272 /// preserved.
273 void preprocessForTileAllocation(IRRewriter &rewriter,
274  FunctionOpInterface function) {
275  splitCondBranches(rewriter, function);
276  insertCopiesAtBranches(rewriter, function);
277 }
278 
279 /// A live range for a (collection of) tile values. A live range is built up of
280 /// non-overlapping intervals [start, end) which represent parts of the program
281 /// where a value in the range needs to be live (i.e. in an SME virtual tile).
282 /// Note that as the intervals are non-overlapping all values within a live
283 /// range can be allocated to the same SME virtual tile.
284 struct LiveRange {
285  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
286  llvm::IntervalMapHalfOpenInfo<unsigned>>;
287  using Allocator = RangeSet::Allocator;
288  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
289  static constexpr uint8_t kValidLiveRange = 0xff;
290 
291  LiveRange(Allocator &allocator)
292  : ranges(std::make_unique<RangeSet>(allocator)) {}
293 
294  /// Returns true if this range overlaps with `otherRange`.
295  bool overlaps(LiveRange const &otherRange) const {
296  return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
297  *otherRange.ranges)
298  .valid();
299  }
300 
301  /// Returns true if this range is active at `point` in the program.
302  bool overlaps(uint64_t point) const {
303  return ranges->lookup(point) == kValidLiveRange;
304  }
305 
306  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
307  void unionWith(LiveRange const &otherRange) {
308  for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
309  ++it)
310  ranges->insert(it.start(), it.stop(), kValidLiveRange);
311  values.set_union(otherRange.values);
312  }
313 
314  /// Inserts an interval [start, end) for `value` into this range.
315  void insert(Value value, unsigned start, unsigned end) {
316  values.insert(value);
317  if (start != end)
318  ranges->insert(start, end, kValidLiveRange);
319  }
320 
321  bool empty() const { return ranges->empty(); }
322  unsigned start() const { return ranges->start(); }
323  unsigned end() const { return ranges->stop(); }
324  bool operator<(LiveRange const &other) const {
325  return start() < other.start();
326  }
327 
328  ArmSMETileType getTileType() const {
329  return *getSMETileType(cast<VectorType>(values[0].getType()));
330  }
331 
332  /// The values contained in this live range.
333  SetVector<Value> values;
334 
335  /// A set of (non-overlapping) intervals that mark where any value in `values`
336  /// is live.
337  std::unique_ptr<RangeSet> ranges;
338 
339  /// The tile ID (or none) assigned to this live range.
340  std::optional<unsigned> tileId;
341 };
342 
343 /// Number operations within a function to allow computing live ranges.
344 /// Operations are numbered consecutively wihin blocks, and the blocks are
345 /// topologically sorted (using forward edges). This function is only correct if
346 /// all ArmSME have been converted to CF (which is asserted).
348 generateOperationNumbering(FunctionOpInterface function) {
349  unsigned index = 0;
350  SetVector<Block *> blocks =
351  getBlocksSortedByDominance(function.getFunctionBody());
352  DenseMap<Operation *, unsigned> operationToIndexMap;
353  for (Block *block : blocks) {
354  index++; // We want block args to have their own number.
355  for (Operation &op : block->getOperations()) {
356 #ifndef NDEBUG
357  op.walk([&](ArmSMETileOpInterface nestedOp) {
358  assert(&op == nestedOp.getOperation() &&
359  "ArmSME tile allocation does not support nested regions");
360  });
361 #endif
362  operationToIndexMap.try_emplace(&op, index++);
363  }
364  }
365  return operationToIndexMap;
366 }
367 
368 /// Gather live ranges for SME tiles from the MLIR liveness analysis.
370 gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
371  LiveRange::Allocator &liveRangeAllocator,
372  Liveness &liveness, FunctionOpInterface function) {
373  assert(!operationToIndexMap.empty() && "expected operation numbering");
374  DenseMap<Value, LiveRange> liveRanges;
375  /// Defines or updates a live range for an SME tile value. Live-ins may update
376  /// an existing live range (rather than define a new one). Note: If
377  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
378  /// the block.
379  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
380  LivenessBlockInfo const &livenessInfo,
381  bool liveAtBlockEntry = false) {
382  if (!isValidSMETileVectorType(value.getType()))
383  return;
384  // Find or create a live range for `value`.
385  auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
386  LiveRange &valueLiveRange = it->second;
387  auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
388  // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
389  unsigned startOpIdx =
390  operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
391  unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
392  valueLiveRange.insert(value, startOpIdx, endOpIdx);
393  };
394 
395  for (Block &block : function.getBlocks()) {
396  LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
397  // Handle block arguments:
398  for (Value argument : block.getArguments())
399  defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
400  /*liveAtBlockEntry=*/true);
401  // Handle live-ins:
402  for (Value liveIn : livenessInfo->in())
403  defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
404  /*liveAtBlockEntry=*/true);
405  // Handle new definitions:
406  for (Operation &op : block) {
407  for (Value result : op.getResults())
408  defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
409  }
410  }
411 
412  return liveRanges;
413 }
414 
415 /// Iterate over all predecessor tile values to a (tile) block argument.
416 static void forEachPredecessorTileValue(BlockArgument blockArg,
417  function_ref<void(Value)> callback) {
418  Block *block = blockArg.getOwner();
419  unsigned argNumber = blockArg.getArgNumber();
420  for (Block *pred : block->getPredecessors()) {
421  TypeSwitch<Operation *>(pred->getTerminator())
422  .Case<cf::BranchOp>([&](auto branch) {
423  Value predecessorOperand = branch.getDestOperands()[argNumber];
424  callback(predecessorOperand);
425  })
426  .Case<cf::CondBranchOp>([&](auto condBranch) {
427  if (condBranch.getFalseDest() == block) {
428  Value predecessorOperand =
429  condBranch.getFalseDestOperands()[argNumber];
430  callback(predecessorOperand);
431  }
432  if (condBranch.getTrueDest() == block) {
433  Value predecessorOperand =
434  condBranch.getTrueDestOperands()[argNumber];
435  callback(predecessorOperand);
436  }
437  });
438  }
439 }
440 
441 /// Coalesce live ranges where it would prevent unnecessary tile moves.
443 coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
444  DenseMap<Value, LiveRange *> liveRanges;
445  for (auto &[value, liveRange] : initialLiveRanges) {
446  liveRanges.insert({value, &liveRange});
447  }
448 
449  // Merge the live ranges of values `a` and `b` into one (if they do not
450  // overlap). After this, the values `a` and `b` will both point to the same
451  // live range (which will contain multiple values).
452  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
453  LiveRange *aLiveRange = liveRanges.at(a);
454  LiveRange *bLiveRange = liveRanges.at(b);
455  if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
456  aLiveRange->unionWith(*bLiveRange);
457  for (Value value : bLiveRange->values)
458  liveRanges[value] = aLiveRange;
459  }
460  };
461 
462  // Merge the live ranges of new definitions with their tile operands.
463  auto unifyDefinitionsWithOperands = [&](Value value) {
464  auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
465  if (!armSMEOp)
466  return;
467  for (auto operand : armSMEOp->getOperands()) {
468  if (isValidSMETileVectorType(operand.getType()))
469  mergeValuesIfNonOverlapping(value, operand);
470  }
471  };
472 
473  // Merge the live ranges of block arguments with their predecessors.
474  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
475  auto blockArg = dyn_cast<BlockArgument>(value);
476  if (!blockArg)
477  return;
478  forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
479  mergeValuesIfNonOverlapping(blockArg, predecessorTile);
480  });
481  };
482 
483  auto applyRule = [&](auto rule) {
484  llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
485  };
486 
487  // Unify as many live ranges as we can. This prevents unnecessary moves.
488  applyRule(unifyBlockArgumentsWithPredecessors);
489  applyRule(unifyDefinitionsWithOperands);
490 
491  // Remove duplicate live range entries.
492  SetVector<LiveRange *> uniqueLiveRanges;
493  for (auto [_, liveRange] : liveRanges) {
494  if (!liveRange->empty())
495  uniqueLiveRanges.insert(liveRange);
496  }
497 
498  // Sort the new live ranges by starting point (ready for tile allocation).
499  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
500  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
501  [](LiveRange *a, LiveRange *b) { return *a < *b; });
502  return std::move(coalescedLiveRanges);
503 }
504 
505 /// Choose a live range to spill (via some heuristics). This picks either a live
506 /// range from `overlappingRanges`, or the new live range `newRange`.
507 template <typename OverlappingRangesIterator>
508 LiveRange *
509 chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
510  LiveRange *newRange) {
511  // Heuristic: Spill trivially copyable operations (usually free).
512  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
513  return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
514  newRange->getTileType()) &&
515  allocatedRange.values.size() == 1 &&
517  allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
518  };
519  if (isTrivialSpill(*newRange))
520  return newRange;
521  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
522  if (trivialSpill != overlappingRanges.end())
523  return &*trivialSpill;
524 
525  // Heuristic: Spill the range that ends last (with a compatible tile type).
526  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
527  return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
528  a.end() < b.end();
529  };
530  LiveRange &latestEndingLiveRange =
531  *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
532  isSmallerTileTypeOrEndsEarlier);
533  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
534  return &latestEndingLiveRange;
535  return newRange;
536 }
537 
538 /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
539 void allocateTilesToLiveRanges(
540  ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
541  TileAllocator tileAllocator;
542  // `activeRanges` = Live ranges that need to be in a tile at the
543  // `currentPoint` in the program.
544  SetVector<LiveRange *> activeRanges;
545  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
546  // at the `currentPoint` in the program but could become active again later.
547  // An inactive section of a live range can be seen as a 'hole' in the live
548  // range, where it is possible to reuse the live range's tile ID _before_ it
549  // has ended. By identifying 'holes', the allocator can reuse tiles more
550  // often, which helps avoid costly tile spills.
551  SetVector<LiveRange *> inactiveRanges;
552  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
553  auto currentPoint = nextRange->start();
554  // 1. Update the `activeRanges` at `currentPoint`.
555  activeRanges.remove_if([&](LiveRange *activeRange) {
556  // Check for live ranges that have expired.
557  if (activeRange->end() <= currentPoint) {
558  tileAllocator.releaseTileId(activeRange->getTileType(),
559  *activeRange->tileId);
560  return true;
561  }
562  // Check for live ranges that have become inactive.
563  if (!activeRange->overlaps(currentPoint)) {
564  tileAllocator.releaseTileId(activeRange->getTileType(),
565  *activeRange->tileId);
566  inactiveRanges.insert(activeRange);
567  return true;
568  }
569  return false;
570  });
571  // 2. Update the `inactiveRanges` at `currentPoint`.
572  inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
573  // Check for live ranges that have expired.
574  if (inactiveRange->end() <= currentPoint) {
575  return true;
576  }
577  // Check for live ranges that have become active.
578  if (inactiveRange->overlaps(currentPoint)) {
579  tileAllocator.acquireTileId(inactiveRange->getTileType(),
580  *inactiveRange->tileId);
581  activeRanges.insert(inactiveRange);
582  return true;
583  }
584  return false;
585  });
586 
587  // 3. Collect inactive live ranges that overlap with the new live range.
588  // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
589  // whereas this checks if there is an overlap at any future point too.
590  SmallVector<LiveRange *> overlappingInactiveRanges;
591  for (LiveRange *inactiveRange : inactiveRanges) {
592  if (inactiveRange->overlaps(*nextRange)) {
593  // We need to reserve the tile IDs of overlapping inactive ranges to
594  // prevent two (overlapping) live ranges from getting the same tile ID.
595  tileAllocator.acquireTileId(inactiveRange->getTileType(),
596  *inactiveRange->tileId);
597  overlappingInactiveRanges.push_back(inactiveRange);
598  }
599  }
600 
601  // 4. Allocate a tile ID to `nextRange`.
602  auto rangeTileType = nextRange->getTileType();
603  auto tileId = tileAllocator.allocateTileId(rangeTileType);
604  if (succeeded(tileId)) {
605  nextRange->tileId = *tileId;
606  } else {
607  // Create an iterator over all overlapping live ranges.
608  auto allOverlappingRanges = llvm::concat<LiveRange>(
609  llvm::make_pointee_range(activeRanges.getArrayRef()),
610  llvm::make_pointee_range(overlappingInactiveRanges));
611  // Choose an overlapping live range to spill.
612  LiveRange *rangeToSpill =
613  chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
614  if (rangeToSpill != nextRange) {
615  // Spill an (in)active live range (so release its tile ID first).
616  tileAllocator.releaseTileId(rangeToSpill->getTileType(),
617  *rangeToSpill->tileId);
618  // This will always succeed after a spill (of an active live range).
619  nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
620  // Remove the live range from the active/inactive sets.
621  if (!activeRanges.remove(rangeToSpill)) {
622  bool removed = inactiveRanges.remove(rangeToSpill);
623  assert(removed && "expected a range to be removed!");
624  (void)removed;
625  }
626  }
627  rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
628  }
629 
630  // 5. Insert the live range into the active ranges.
631  if (nextRange->tileId < kInMemoryTileIdBase)
632  activeRanges.insert(nextRange);
633 
634  // 6. Release tiles reserved for inactive live ranges (in step 3).
635  for (LiveRange *range : overlappingInactiveRanges) {
636  if (*range->tileId < kInMemoryTileIdBase)
637  tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
638  }
639  }
640 }
641 
642 /// Assigns a tile ID to an MLIR value.
643 void assignTileIdToValue(IRRewriter &rewriter, Value value,
644  IntegerAttr tileIdAttr) {
645  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
646  rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
647  for (Operation *user : value.getUsers()) {
648  if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
649  // Ensure ArmSME ops that don't produce a value still get a tile ID.
650  if (!hasTileResult(tileOp))
651  rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
652  }
653  }
654 }
655 
656 /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
657 LogicalResult assignTileIdsAndResolveTrivialConflicts(
658  IRRewriter &rewriter, FunctionOpInterface function,
659  ArrayRef<LiveRange *> allocatedLiveRanges) {
660  for (LiveRange const *liveRange : allocatedLiveRanges) {
661  auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
662  auto isAllocatedToSameTile = [&](Value value) {
663  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
664  tileOp && tileOp.getTileId() == tileIdAttr)
665  return true;
666  return liveRange->values.contains(value);
667  };
668 
669  /// Eliminates copies where the operand has the same tile ID.
670  auto foldRedundantCopies = [&](Value value) -> LogicalResult {
671  auto copyOp = value.getDefiningOp<CopyTileOp>();
672  if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
673  return failure();
674  rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
675  return success();
676  };
677 
678  /// Validates each predecessor to a tile block argument has been assigned
679  /// the same tile ID.
680  auto validateBlockArguments = [&](Value value) {
681  auto blockArg = dyn_cast<BlockArgument>(value);
682  if (!blockArg) {
683  // Not a block argument (nothing to validate).
684  return success();
685  }
686  bool tileMismatch = false;
687  forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
688  if (tileMismatch)
689  return;
690  if (!isAllocatedToSameTile(predecessorTile)) {
691  blockArg.getOwner()->getParentOp()->emitOpError(
692  "block argument not allocated to the same SME virtial tile as "
693  "predecessors");
694  tileMismatch = true;
695  }
696  });
697  return success(/*isSuccess=*/!tileMismatch);
698  };
699 
700  /// Attempts to resolve (trivial) tile ID conflicts.
701  auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
702  auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
703  OpOperand *tileOperand = getTileOpOperand(tileOp);
704  if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
705  // Operand already allocated to the correct tile.
706  // No conflict to resolve.
707  return success();
708  }
709  auto operandTileOp =
710  tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
711  if (!isTriviallyCloneableTileOp(operandTileOp)) {
712  auto error =
713  tileOp.emitOpError("tile operand allocated to different SME "
714  "virtial tile (move required)");
715  error.attachNote(tileOperand->get().getLoc())
716  << "tile operand is: " << tileOperand->get();
717  return error;
718  }
719  // Cloning prevents a move/spill (though may require recomputation).
720  rewriter.setInsertionPoint(tileOp);
721  auto clonedOp = operandTileOp.clone();
722  rewriter.modifyOpInPlace(clonedOp,
723  [&] { clonedOp.setTileId(tileOp.getTileId()); });
724  rewriter.insert(clonedOp);
725  if (isa<CopyTileOp>(tileOp)) {
726  rewriter.replaceAllUsesWith(tileOp->getResult(0),
727  clonedOp->getResult(0));
728  } else {
729  rewriter.modifyOpInPlace(
730  tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
731  }
732  return success();
733  };
734 
735  for (Value value : liveRange->values) {
736  // 1. Assign the tile ID to the value.
737  assignTileIdToValue(rewriter, value, tileIdAttr);
738 
739  // 2. Attempt to eliminate redundant tile copies.
740  if (succeeded(foldRedundantCopies(value)))
741  continue;
742 
743  // 3. Validate tile block arguments.
744  if (failed(validateBlockArguments(value)))
745  return failure();
746 
747  // 4. Attempt to resolve (trivial) tile ID conflicts.
748  if (failed(resolveTrivialTileConflicts(value)))
749  return failure();
750  }
751  }
752  return success();
753 }
754 
755 /// Prints live ranges alongside operation names for debugging.
756 void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
757  ArrayRef<LiveRange const *> liveRanges,
758  FunctionOpInterface function) {
759  llvm::errs() << "SME Tile Liveness: @" << function.getName()
760  << "\nKey:\nS - Start\nE - End\n| - Live\n";
761  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
762  llvm::errs() << "^bb" << blockIdx << ":\n";
763  for (Operation &op : block.getOperations()) {
764  unsigned operationIndex = operationToIndexMap.at(&op);
765  for (LiveRange const *range : liveRanges) {
766  char liveness = ' ';
767  for (auto it = range->ranges->begin(); it != range->ranges->end();
768  ++it) {
769  if (it.start() == operationIndex)
770  liveness = (liveness == 'E' ? '|' : 'S');
771  else if (it.stop() == operationIndex)
772  liveness = (liveness == 'S' ? '|' : 'E');
773  else if (operationIndex >= it.start() && operationIndex < it.stop())
774  liveness = '|';
775  }
776  llvm::errs() << liveness;
777  }
778  llvm::errs() << ' ' << op.getName() << '\n';
779  }
780  }
781  llvm::errs() << "==========\n";
782 }
783 
784 struct TestTileAllocationPass
785  : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
786  using TestTileAllocationBase::TestTileAllocationBase;
787  void runOnOperation() override {
788  FunctionOpInterface function = getOperation();
789  if (preprocessOnly) {
790  IRRewriter rewriter(function);
791  return preprocessForTileAllocation(rewriter, function);
792  }
793  if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
794  signalPassFailure();
795  }
796 };
797 } // namespace
798 
799 LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
800  bool dumpRanges) {
801  if (function.empty()) {
802  // TODO: Also return early if the function contains no ArmSME ops?
803  return success();
804  }
805 
806  LiveRange::Allocator liveRangeAllocator;
807  IRRewriter rewriter(function.getContext());
808 
809  // 1. Preprocess the IR for tile allocation.
810  preprocessForTileAllocation(rewriter, function);
811 
812  // 2. Gather live ranges for each ArmSME tile within the function.
813  Liveness liveness(function);
814  auto operationToIndexMap = generateOperationNumbering(function);
815  auto initialLiveRanges = gatherTileLiveRanges(
816  operationToIndexMap, liveRangeAllocator, liveness, function);
817  if (initialLiveRanges.empty())
818  return success();
819 
820  if (dumpRanges) {
821  // Wrangle initial live ranges into a form suitable for printing.
822  auto nonEmpty = llvm::make_filter_range(
823  llvm::make_second_range(initialLiveRanges),
824  [&](LiveRange const &liveRange) { return !liveRange.empty(); });
825  auto initialRanges = llvm::to_vector(llvm::map_range(
826  nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
827  std::sort(initialRanges.begin(), initialRanges.end(),
828  [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
829  llvm::errs() << "\n========== Initial Live Ranges:\n";
830  dumpLiveRanges(operationToIndexMap, initialRanges, function);
831  }
832 
833  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
834  // for tile allocation. E.g. Unify the result of an operation with its
835  // operands.
836  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
837 
838  if (dumpRanges) {
839  llvm::errs() << "\n========== Coalesced Live Ranges:\n";
840  dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
841  }
842 
843  // 4. Allocate tile IDs to live ranges.
844  allocateTilesToLiveRanges(coalescedLiveRanges);
845 
846  // 5. Assign the tile IDs back to the ArmSME operations.
847  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
848  coalescedLiveRanges))) {
849  return failure();
850  }
851 
852  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
853  // users). This prevents the LLVM conversion needlessly inserting spills.
854  eraseTriviallyDeadTileOps(rewriter, function);
855  return success();
856 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:235
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
Definition: Liveness.cpp:229
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:461
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:44
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition: Utils.cpp:119
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition: Utils.cpp:139
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= (in terms of bytes) typeB.
Definition: Utils.cpp:167
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition: Utils.cpp:152
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition: Utils.cpp:144
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:83
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305