MLIR 22.0.0git
TileAllocation.cpp
Go to the documentation of this file.
1//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transform allocates SME tiles at the 'func.func' op level for ArmSME
10// operations. It roughly implements a linear scan register allocator, similar
11// to the one outlined in [1], but with simplifications and assumptions made for
12// our use case. Note that this is a greedy allocator (so it may not always find
13// the most optimal allocation of tiles).
14//
15// The allocator operates at the CF dialect level. It is the responsibility of
16// users to ensure the IR has been lowered to CF before invoking the tile
17// allocator.
18//
19// The 128-bit tiles overlap with other element tiles as follows (see section
20// B2.3.2 of SME spec [2]):
21//
22// Tile Overlaps
23// ---------------------------------------------------------------------------
24// ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
25// ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
26// ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
27// ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
28// ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
29// ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
30// ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
31// ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
32// ZA0.D ZA0.Q, ZA8.Q
33// ZA1.D ZA1.Q, ZA9.Q
34// ZA2.D ZA2.Q, ZA10.Q
35// ZA3.D ZA3.Q, ZA11.Q
36// ZA4.D ZA4.Q, ZA12.Q
37// ZA5.D ZA5.Q, ZA13.Q
38// ZA6.D ZA6.Q, ZA14.Q
39// ZA7.D ZA7.Q, ZA15.Q
40//
41// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44// [2] https://developer.arm.com/documentation/ddi0616/aa
45//
46//===----------------------------------------------------------------------===//
47
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
57
58namespace mlir::arm_sme {
59#define GEN_PASS_DEF_TESTTILEALLOCATION
60#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
61} // namespace mlir::arm_sme
62
63using namespace mlir;
64using namespace mlir::arm_sme;
65
66namespace {
67
68enum class TileMask : unsigned {
69 // clang-format off
70 kZA0B = 0xffff, // 1111 1111 1111 1111
71
72 kZA0H = 0xaaaa, // 1010 1010 1010 1010
73 kZA1H = 0x5555, // 0101 0101 0101 0101
74
75 kZA0S = 0x8888, // 1000 1000 1000 1000
76 kZA1S = 0x4444, // 0100 0100 0100 0100
77 kZA2S = 0x2222, // 0010 0010 0010 0010
78 kZA3S = 0x1111, // 0001 0001 0001 0001
79
80 kZA0D = 0x8080, // 1000 0000 1000 0000
81 kZA1D = 0x4040, // 0100 0000 0100 0000
82 kZA2D = 0x2020, // 0010 0000 0010 0000
83 kZA3D = 0x1010, // 0001 0000 0001 0000
84 kZA4D = 0x808, // 0000 1000 0000 1000
85 kZA5D = 0x404, // 0000 0100 0000 0100
86 kZA6D = 0x202, // 0000 0010 0000 0010
87 kZA7D = 0x101, // 0000 0001 0000 0001
88
89 kZA0Q = 0x8000, // 1000 0000 0000 0000
90 kZA1Q = 0x4000, // 0100 0000 0000 0000
91 kZA2Q = 0x2000, // 0010 0000 0000 0000
92 kZA3Q = 0x1000, // 0001 0000 0000 0000
93 kZA4Q = 0x800, // 0000 1000 0000 0000
94 kZA5Q = 0x400, // 0000 0100 0000 0000
95 kZA6Q = 0x200, // 0000 0010 0000 0000
96 kZA7Q = 0x100, // 0000 0001 0000 0000
97 kZA8Q = 0x80, // 0000 0000 1000 0000
98 kZA9Q = 0x40, // 0000 0000 0100 0000
99 kZA10Q = 0x20, // 0000 0000 0010 0000
100 kZA11Q = 0x10, // 0000 0000 0001 0000
101 kZA12Q = 0x8, // 0000 0000 0000 1000
102 kZA13Q = 0x4, // 0000 0000 0000 0100
103 kZA14Q = 0x2, // 0000 0000 0000 0010
104 kZA15Q = 0x1, // 0000 0000 0000 0001
105
106 kNone = 0x0, // 0000 0000 0000 0000
107 // clang-format on
108
110};
111
112/// Returns the set of masks relevant for the given type.
113static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
114 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
115 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
116 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117 TileMask::kZA2S, TileMask::kZA3S};
118 static constexpr std::array ZA_D_MASKS = {
119 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
121 static constexpr std::array ZA_Q_MASKS = {
122 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
126 switch (type) {
127 case ArmSMETileType::ZAB:
128 return ZA_B_MASKS;
129 case ArmSMETileType::ZAH:
130 return ZA_H_MASKS;
131 case ArmSMETileType::ZAS:
132 return ZA_S_MASKS;
133 case ArmSMETileType::ZAD:
134 return ZA_D_MASKS;
135 case ArmSMETileType::ZAQ:
136 return ZA_Q_MASKS;
137 }
138 llvm_unreachable("unknown type in getMasks");
139}
140
141class TileAllocator {
142public:
143 /// Allocates and returns a tile ID. Fails if there are no tiles left.
144 FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
145 auto masks = getMasks(tileType);
146 for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
147 if ((tilesInUse & tileMask) == TileMask::kNone) {
148 tilesInUse |= tileMask;
149 return tileId;
150 }
151 }
152 return failure();
153 }
154
155 /// Acquires a specific tile ID. Asserts the tile is initially free.
156 void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
157 TileMask tileMask = getMasks(tileType)[tileId];
158 assert((tilesInUse & tileMask) == TileMask::kNone &&
159 "cannot acquire allocated tile!");
160 tilesInUse |= tileMask;
161 }
162
163 /// Releases a previously allocated tile ID.
164 void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
165 TileMask tileMask = getMasks(tileType)[tileId];
166 assert((tilesInUse & tileMask) == tileMask &&
167 "cannot release unallocated tile!");
168 tilesInUse ^= tileMask;
169 }
170
171 /// Allocates an in-memory tile ID.
172 unsigned allocateInMemoryTileId() {
173 // Note: We never release in-memory tile IDs. We could, which may allow
174 // reusing an allocation, but as we _never_ want to spill an SME tile this
175 // is not optimized.
176 return nextInMemoryTileId++;
177 }
178
179private:
180 TileMask tilesInUse = TileMask::kNone;
181 unsigned nextInMemoryTileId = kInMemoryTileIdBase;
182};
183
184/// Add new intermediate blocks for the true and false destinations of
185/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
186/// overlaps due to copies at branches.
187///
188/// BEFORE:
189/// ```mlir
190/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
191/// ```
192///
193/// AFTER:
194/// ```mlir
195/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
196/// ^bb1_copy:
197/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
198/// ^bb2_copy:
199/// cf.br ^bb2
200/// ```
201void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
203 function.walk([&](cf::CondBranchOp condBranch) {
204 if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
205 return isValidSMETileVectorType(value.getType());
206 })) {
207 worklist.push_back(condBranch);
208 }
209 });
210
211 auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
212 rewriter.setInsertionPointToEnd(source);
213 cf::BranchOp::create(rewriter, loc, dest, args);
214 };
215
216 for (auto condBranch : worklist) {
217 auto loc = condBranch.getLoc();
218 Block *block = condBranch->getBlock();
219 auto *newTrueBranch = rewriter.splitBlock(block, block->end());
220 auto *newFalseBranch = rewriter.splitBlock(block, block->end());
221 insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
222 condBranch.getTrueDestOperands());
223 insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
224 condBranch.getFalseDestOperands());
225 rewriter.modifyOpInPlace(condBranch, [&] {
226 condBranch.getFalseDestOperandsMutable().clear();
227 condBranch.getTrueDestOperandsMutable().clear();
228 condBranch.setSuccessor(newTrueBranch, 0);
229 condBranch.setSuccessor(newFalseBranch, 1);
230 });
231 }
233
234/// Inserts tile copies at `cf.br` operations.
235///
236/// BEFORE:
237/// ```mlir
238/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
239/// ```
240///
241/// AFTER:
242/// ```mlir
243/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
244/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
245/// ```
246void insertCopiesAtBranches(IRRewriter &rewriter,
247 FunctionOpInterface function) {
248 for (Block &block : function.getBlocks()) {
249 Operation *terminator = block.getTerminator();
250 if (!isa<cf::BranchOp>(terminator))
251 continue;
252 rewriter.setInsertionPoint(terminator);
253 for (OpOperand &operand : terminator->getOpOperands()) {
254 if (isValidSMETileVectorType(operand.get().getType())) {
255 auto copy =
256 CopyTileOp::create(rewriter, terminator->getLoc(), operand.get());
257 rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
258 }
259 }
260 }
261}
262
/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure when moving out of SSA the semantics of the program are
/// preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  // Order matters: splitting first means the copies for a `cf.cond_br` land
  // in the new intermediate blocks (which end in plain `cf.br` ops, the only
  // terminators `insertCopiesAtBranches` handles).
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  /// Note: `lookup` returns the default value (0) for points outside every
  /// interval, which never compares equal to `kValidLiveRange`.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  /// An empty interval (start == end) records the value without extending
  /// the range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  /// Returns true if this range contains no intervals.
  bool empty() const { return ranges->empty(); }
  /// First program point covered by this range.
  unsigned start() const { return ranges->start(); }
  /// One past the last program point covered by this range.
  unsigned end() const { return ranges->stop(); }
  /// Orders live ranges by start point (for the linear scan).
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  /// Tile type of this range, taken from the first value. This is well-defined
  /// as overlapping-type values are only merged into the same range.
  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
340
341/// Number operations within a function to allow computing live ranges.
342/// Operations are numbered consecutively wihin blocks, and the blocks are
343/// topologically sorted (using forward edges). This function is only correct if
344/// all ArmSME have been converted to CF (which is asserted).
346generateOperationNumbering(FunctionOpInterface function) {
347 unsigned index = 0;
348 SetVector<Block *> blocks =
349 getBlocksSortedByDominance(function.getFunctionBody());
350 DenseMap<Operation *, unsigned> operationToIndexMap;
351 for (Block *block : blocks) {
352 index++; // We want block args to have their own number.
353 for (Operation &op : block->getOperations()) {
354#ifndef NDEBUG
355 op.walk([&](ArmSMETileOpInterface nestedOp) {
356 assert(&op == nestedOp.getOperation() &&
357 "ArmSME tile allocation does not support nested regions");
358 });
359#endif
360 operationToIndexMap.try_emplace(&op, index++);
361 }
362 }
363 return operationToIndexMap;
364}
365
366/// Gather live ranges for SME tiles from the MLIR liveness analysis.
368gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
369 LiveRange::Allocator &liveRangeAllocator,
370 Liveness &liveness, FunctionOpInterface function) {
371 assert(!operationToIndexMap.empty() && "expected operation numbering");
373 /// Defines or updates a live range for an SME tile value. Live-ins may update
374 /// an existing live range (rather than define a new one). Note: If
375 /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
376 /// the block.
377 auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
378 LivenessBlockInfo const &livenessInfo,
379 bool liveAtBlockEntry = false) {
380 if (!isValidSMETileVectorType(value.getType()))
381 return;
382 // Find or create a live range for `value`.
383 auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
384 LiveRange &valueLiveRange = it->second;
385 auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
386 // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
387 unsigned startOpIdx =
388 operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
389 unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
390 valueLiveRange.insert(value, startOpIdx, endOpIdx);
391 };
392
393 for (Block &block : function.getBlocks()) {
394 LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
395 // Handle block arguments:
396 for (Value argument : block.getArguments())
397 defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
398 /*liveAtBlockEntry=*/true);
399 // Handle live-ins:
400 for (Value liveIn : livenessInfo->in())
401 defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
402 /*liveAtBlockEntry=*/true);
403 // Handle new definitions:
404 for (Operation &op : block) {
405 for (Value result : op.getResults())
406 defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
407 }
408 }
409
410 return liveRanges;
411}
412
413/// Iterate over all predecessor tile values to a (tile) block argument.
414static void forEachPredecessorTileValue(BlockArgument blockArg,
415 function_ref<void(Value)> callback) {
416 Block *block = blockArg.getOwner();
417 unsigned argNumber = blockArg.getArgNumber();
418 for (Block *pred : block->getPredecessors()) {
419 TypeSwitch<Operation *>(pred->getTerminator())
420 .Case<cf::BranchOp>([&](auto branch) {
421 Value predecessorOperand = branch.getDestOperands()[argNumber];
422 callback(predecessorOperand);
423 })
424 .Case<cf::CondBranchOp>([&](auto condBranch) {
425 if (condBranch.getFalseDest() == block) {
426 Value predecessorOperand =
427 condBranch.getFalseDestOperands()[argNumber];
428 callback(predecessorOperand);
429 }
430 if (condBranch.getTrueDest() == block) {
431 Value predecessorOperand =
432 condBranch.getTrueDestOperands()[argNumber];
433 callback(predecessorOperand);
434 }
435 });
436 }
437}
438
439/// Coalesce live ranges where it would prevent unnecessary tile moves.
441coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
443 for (auto &[value, liveRange] : initialLiveRanges) {
444 liveRanges.insert({value, &liveRange});
445 }
446
447 // Merge the live ranges of values `a` and `b` into one (if they do not
448 // overlap). After this, the values `a` and `b` will both point to the same
449 // live range (which will contain multiple values).
450 auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
451 LiveRange *aLiveRange = liveRanges.at(a);
452 LiveRange *bLiveRange = liveRanges.at(b);
453 if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
454 aLiveRange->unionWith(*bLiveRange);
455 for (Value value : bLiveRange->values)
456 liveRanges[value] = aLiveRange;
457 }
458 };
459
460 // Merge the live ranges of new definitions with their tile operands.
461 auto unifyDefinitionsWithOperands = [&](Value value) {
462 auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
463 if (!armSMEOp)
464 return;
465 for (auto operand : armSMEOp->getOperands()) {
466 if (isValidSMETileVectorType(operand.getType()))
467 mergeValuesIfNonOverlapping(value, operand);
468 }
469 };
470
471 // Merge the live ranges of block arguments with their predecessors.
472 auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
473 auto blockArg = dyn_cast<BlockArgument>(value);
474 if (!blockArg)
475 return;
476 forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
477 mergeValuesIfNonOverlapping(blockArg, predecessorTile);
478 });
479 };
480
481 auto applyRule = [&](auto rule) {
482 llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
483 };
484
485 // Unify as many live ranges as we can. This prevents unnecessary moves.
486 applyRule(unifyBlockArgumentsWithPredecessors);
487 applyRule(unifyDefinitionsWithOperands);
488
489 // Remove duplicate live range entries.
490 SetVector<LiveRange *> uniqueLiveRanges;
491 for (auto [_, liveRange] : liveRanges) {
492 if (!liveRange->empty())
493 uniqueLiveRanges.insert(liveRange);
494 }
495
496 // Sort the new live ranges by starting point (ready for tile allocation).
497 auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
498 llvm::sort(coalescedLiveRanges,
499 [](LiveRange *a, LiveRange *b) { return *a < *b; });
500 return std::move(coalescedLiveRanges);
501}
502
503/// Choose a live range to spill (via some heuristics). This picks either a live
504/// range from `overlappingRanges`, or the new live range `newRange`.
505template <typename OverlappingRangesIterator>
506LiveRange *
507chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
508 LiveRange *newRange) {
509 // Heuristic: Spill trivially copyable operations (usually free).
510 auto isTrivialSpill = [&](LiveRange &allocatedRange) {
511 return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
512 newRange->getTileType()) &&
513 allocatedRange.values.size() == 1 &&
515 allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
516 };
517 if (isTrivialSpill(*newRange))
518 return newRange;
519 auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
520 if (trivialSpill != overlappingRanges.end())
521 return &*trivialSpill;
522
523 // Heuristic: Spill the range that ends last (with a compatible tile type).
524 auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
525 return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
526 a.end() < b.end();
527 };
528 LiveRange &latestEndingLiveRange =
529 *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
530 if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
531 return &latestEndingLiveRange;
532 return newRange;
533}
534
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
/// `liveRangesSortedByStartPoint` must be sorted by `LiveRange::start()` (as
/// produced by `coalesceTileLiveRanges`); ranges chosen for spilling are given
/// an in-memory tile ID (>= kInMemoryTileIdBase).
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      // Their tile ID is released (reusable) but remembered on the range so
      // it can be re-acquired when the range becomes active again.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    // Ranges that were just spilled (given an in-memory ID) never become
    // active, so their IDs are not tracked by the tile allocator.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
638
639/// Assigns a tile ID to an MLIR value.
640void assignTileIdToValue(IRRewriter &rewriter, Value value,
641 IntegerAttr tileIdAttr) {
642 if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
643 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
644 for (Operation *user : value.getUsers()) {
645 if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
646 // Ensure ArmSME ops that don't produce a value still get a tile ID.
647 if (!hasTileResult(tileOp))
648 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
649 }
650 }
651}
652
653/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
654LogicalResult assignTileIdsAndResolveTrivialConflicts(
655 IRRewriter &rewriter, FunctionOpInterface function,
656 ArrayRef<LiveRange *> allocatedLiveRanges) {
657 for (LiveRange const *liveRange : allocatedLiveRanges) {
658 auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
659 auto isAllocatedToSameTile = [&](Value value) {
660 if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
661 tileOp && tileOp.getTileId() == tileIdAttr)
662 return true;
663 return liveRange->values.contains(value);
664 };
665
666 /// Eliminates copies where the operand has the same tile ID.
667 auto foldRedundantCopies = [&](Value value) -> LogicalResult {
668 auto copyOp = value.getDefiningOp<CopyTileOp>();
669 if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
670 return failure();
671 rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
672 return success();
673 };
674
675 /// Validates each predecessor to a tile block argument has been assigned
676 /// the same tile ID.
677 auto validateBlockArguments = [&](Value value) {
678 auto blockArg = dyn_cast<BlockArgument>(value);
679 if (!blockArg) {
680 // Not a block argument (nothing to validate).
681 return success();
682 }
683 bool tileMismatch = false;
684 forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
685 if (tileMismatch)
686 return;
687 if (!isAllocatedToSameTile(predecessorTile)) {
688 blockArg.getOwner()->getParentOp()->emitOpError(
689 "block argument not allocated to the same SME virtial tile as "
690 "predecessors");
691 tileMismatch = true;
692 }
693 });
694 return success(/*isSuccess=*/!tileMismatch);
695 };
696
697 /// Attempts to resolve (trivial) tile ID conflicts.
698 auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
699 auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
700 OpOperand *tileOperand = getTileOpOperand(tileOp);
701 if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
702 // Operand already allocated to the correct tile.
703 // No conflict to resolve.
704 return success();
705 }
706 auto operandTileOp =
707 tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
708 if (!isTriviallyCloneableTileOp(operandTileOp)) {
709 auto error =
710 tileOp.emitOpError("tile operand allocated to different SME "
711 "virtial tile (move required)");
712 error.attachNote(tileOperand->get().getLoc())
713 << "tile operand is: " << tileOperand->get();
714 return error;
715 }
716 // Cloning prevents a move/spill (though may require recomputation).
717 rewriter.setInsertionPoint(tileOp);
718 auto clonedOp = operandTileOp.clone();
719 rewriter.modifyOpInPlace(clonedOp,
720 [&] { clonedOp.setTileId(tileOp.getTileId()); });
721 rewriter.insert(clonedOp);
722 if (isa<CopyTileOp>(tileOp)) {
723 rewriter.replaceAllUsesWith(tileOp->getResult(0),
724 clonedOp->getResult(0));
725 } else {
726 rewriter.modifyOpInPlace(
727 tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
728 }
729 return success();
730 };
731
732 for (Value value : liveRange->values) {
733 // 1. Assign the tile ID to the value.
734 assignTileIdToValue(rewriter, value, tileIdAttr);
735
736 // 2. Attempt to eliminate redundant tile copies.
737 if (succeeded(foldRedundantCopies(value)))
738 continue;
739
740 // 3. Validate tile block arguments.
741 if (failed(validateBlockArguments(value)))
742 return failure();
743
744 // 4. Attempt to resolve (trivial) tile ID conflicts.
745 if (failed(resolveTrivialTileConflicts(value)))
746 return failure();
747 }
748 }
749 return success();
750}
751
752/// Prints live ranges alongside operation names for debugging.
753void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
755 FunctionOpInterface function) {
756 llvm::errs() << "SME Tile Liveness: @" << function.getName()
757 << "\nKey:\nS - Start\nE - End\n| - Live\n";
758 for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
759 llvm::errs() << "^bb" << blockIdx << ":\n";
760 for (Operation &op : block.getOperations()) {
761 unsigned operationIndex = operationToIndexMap.at(&op);
762 for (LiveRange const *range : liveRanges) {
763 char liveness = ' ';
764 for (auto it = range->ranges->begin(); it != range->ranges->end();
765 ++it) {
766 if (it.start() == operationIndex)
767 liveness = (liveness == 'E' ? '|' : 'S');
768 else if (it.stop() == operationIndex)
769 liveness = (liveness == 'S' ? '|' : 'E');
770 else if (operationIndex >= it.start() && operationIndex < it.stop())
771 liveness = '|';
772 }
773 llvm::errs() << liveness;
774 }
775 llvm::errs() << ' ' << op.getName() << '\n';
776 }
777 }
778 llvm::errs() << "==========\n";
779}
780
781struct TestTileAllocationPass
782 : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
783 using TestTileAllocationBase::TestTileAllocationBase;
784 void runOnOperation() override {
785 FunctionOpInterface function = getOperation();
786 if (preprocessOnly) {
787 IRRewriter rewriter(function);
788 return preprocessForTileAllocation(rewriter, function);
789 }
790 if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
791 signalPassFailure();
792 }
793};
794} // namespace
795
796LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
797 bool dumpRanges) {
798 if (function.empty()) {
799 // TODO: Also return early if the function contains no ArmSME ops?
800 return success();
801 }
802
803 LiveRange::Allocator liveRangeAllocator;
804 IRRewriter rewriter(function.getContext());
805
806 // 1. Preprocess the IR for tile allocation.
807 preprocessForTileAllocation(rewriter, function);
808
809 // 2. Gather live ranges for each ArmSME tile within the function.
810 Liveness liveness(function);
811 auto operationToIndexMap = generateOperationNumbering(function);
812 auto initialLiveRanges = gatherTileLiveRanges(
813 operationToIndexMap, liveRangeAllocator, liveness, function);
814 if (initialLiveRanges.empty())
815 return success();
816
817 if (dumpRanges) {
818 // Wrangle initial live ranges into a form suitable for printing.
819 auto nonEmpty = llvm::make_filter_range(
820 llvm::make_second_range(initialLiveRanges),
821 [&](LiveRange const &liveRange) { return !liveRange.empty(); });
822 auto initialRanges = llvm::to_vector(llvm::map_range(
823 nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
824 llvm::sort(initialRanges,
825 [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
826 llvm::errs() << "\n========== Initial Live Ranges:\n";
827 dumpLiveRanges(operationToIndexMap, initialRanges, function);
828 }
829
830 // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
831 // for tile allocation. E.g. Unify the result of an operation with its
832 // operands.
833 auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
834
835 if (dumpRanges) {
836 llvm::errs() << "\n========== Coalesced Live Ranges:\n";
837 dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
838 }
839
840 // 4. Allocate tile IDs to live ranges.
841 allocateTilesToLiveRanges(coalescedLiveRanges);
842
843 // 5. Assign the tile IDs back to the ArmSME operations.
844 if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
845 coalescedLiveRanges))) {
846 return failure();
847 }
848
849 // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
850 // users). This prevents the LLVM conversion needlessly inserting spills.
851 eraseTriviallyDeadTileOps(rewriter, function);
852 return success();
853}
return success()
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:240
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents liveness information on block level.
Definition Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition Liveness.h:110
Represents an analysis for computing liveness information from a given top-level operation.
Definition Liveness.h:47
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
Definition Liveness.cpp:225
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class represents an operand of an operation.
Definition Value.h:257
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
mlir::IntegerAttr getTileId()
Returns the tile ID assigned to this operation.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition Utils.cpp:58
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition Utils.cpp:133
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition Utils.cpp:153
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true typeA is >= (in terms of bytes) than typeB.
Definition Utils.cpp:181
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition Utils.cpp:166
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition Utils.cpp:158
bool operator<(const Fraction &x, const Fraction &y)
Definition Fraction.h:83
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152