TileAllocation.cpp
//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made
// for our use case. Note that this is a greedy allocator (so it may not always
// find an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of the SME spec [2]):
//
//   Tile    Overlaps
//   --------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
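//
// For example, once ZA1.S is allocated (using ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q),
// neither ZA1.H nor ZA1.D can also be allocated, but ZA0.D (ZA0.Q, ZA8.Q)
// still can.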
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//     Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  // Enables the bitwise operators (&, |=, ^=) used on TileMask below.
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/kZA0B)
};
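
// For example (illustrative): kZA0S selects ZA0.Q, ZA4.Q, ZA8.Q and ZA12.Q,
// and kZA0D selects ZA0.Q and ZA8.Q, so (kZA0S & kZA0D) != kNone and the two
// tiles conflict. kZA1D (ZA1.Q, ZA9.Q) shares no bits with kZA0S, so a 32-bit
// tile in ZA0.S and a 64-bit tile in ZA1.D can coexist.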

/// Returns the set of masks relevant for the given type.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  llvm_unreachable("unknown type in getMasks");
}

class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which may allow
    // reusing an allocation, but as we _never_ want to spill an SME tile this
    // is not optimized.
    return nextInMemoryTileId++;
  }

private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
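
// Illustrative usage of the allocator above (not part of the pass). The tile
// IDs follow from the overlap masks: ZA0.D shares ZA0.Q/ZA8.Q with ZA0.S, so
// the first free 64-bit tile after allocating a 32-bit tile is ZA1.D.
//
//   TileAllocator allocator;
//   FailureOr<unsigned> sTile = allocator.allocateTileId(ArmSMETileType::ZAS);
//   // *sTile == 0 (ZA0.S)
//   FailureOr<unsigned> dTile = allocator.allocateTileId(ArmSMETileType::ZAD);
//   // *dTile == 1 (ZA1.D, since ZA0.D overlaps ZA0.S)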

/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
/// BEFORE:
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// AFTER:
/// ```mlir
/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        })) {
      worklist.push_back(condBranch);
    }
  });

  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    cf::BranchOp::create(rewriter, loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    auto *newTrueBranch = rewriter.splitBlock(block, block->end());
    auto *newFalseBranch = rewriter.splitBlock(block, block->end());
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(newTrueBranch, 0);
      condBranch.setSuccessor(newFalseBranch, 1);
    });
  }
}

/// Inserts tile copies at `cf.br` operations.
///
/// BEFORE:
/// ```mlir
/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ```
///
/// AFTER:
/// ```mlir
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
/// ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *terminator = block.getTerminator();
    if (!isa<cf::BranchOp>(terminator))
      continue;
    rewriter.setInsertionPoint(terminator);
    for (OpOperand &operand : terminator->getOpOperands()) {
      if (isValidSMETileVectorType(operand.get().getType())) {
        auto copy =
            CopyTileOp::create(rewriter, terminator->getLoc(), operand.get());
        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
      }
    }
  }
}

/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them from overlapping between the true and false paths of
/// the branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure that, when moving out of SSA, the semantics of the
/// program are preserved.
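///
/// For example, combining the two steps above, a conditional branch with a
/// tile operand (illustrative IR, simplified):
///
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// becomes:
///
/// ```mlir
/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
///   cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```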
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}

/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping, all values within a live
/// range can be allocated to the same SME virtual tile.
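///
/// A minimal usage sketch (`tileValue` is a placeholder SME tile value):
///
///   LiveRange::Allocator allocator;
///   LiveRange range(allocator);
///   range.insert(tileValue, /*start=*/2, /*end=*/5);  // Live over [2, 5).
///   range.insert(tileValue, /*start=*/9, /*end=*/12); // Live over [9, 12).
///   range.overlaps(3); // true
///   range.overlaps(6); // false: a 'hole' where the tile could be reused.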
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in
  /// `values` is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};

/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all ArmSME operations have been converted to CF (which is asserted).
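///
/// A minimal sketch of the numbering (illustrative ops):
///
/// ```mlir
/// ^bb0:                  // ^bb0 block arguments -> 0
///   "op_a"() : () -> ()  // 1
///   cf.br ^bb1           // 2
/// ^bb1:                  // ^bb1 block arguments -> 3
///   "op_b"() : () -> ()  // 4
/// ```
///
/// The index skipped at each block entry is the one later used for that
/// block's arguments (see `gatherTileLiveRanges`).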
static DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}

/// Gather live ranges for SME tiles from the MLIR liveness analysis.
static DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may
  /// update an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
  /// the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}

/// Iterate over all predecessor tile values to a (tile) block argument.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *block = blockArg.getOwner();
  unsigned argNumber = blockArg.getArgNumber();
  for (Block *pred : block->getPredecessors()) {
    TypeSwitch<Operation *>(pred->getTerminator())
        .Case([&](cf::BranchOp branch) {
          Value predecessorOperand = branch.getDestOperands()[argNumber];
          callback(predecessorOperand);
        })
        .Case([&](cf::CondBranchOp condBranch) {
          if (condBranch.getFalseDest() == block) {
            Value predecessorOperand =
                condBranch.getFalseDestOperands()[argNumber];
            callback(predecessorOperand);
          }
          if (condBranch.getTrueDest() == block) {
            Value predecessorOperand =
                condBranch.getTrueDestOperands()[argNumber];
            callback(predecessorOperand);
          }
        });
  }
}

/// Coalesce live ranges where it would prevent unnecessary tile moves.
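///
/// For example (illustrative IR), given:
///
/// ```mlir
///   %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
///   cf.br ^bb1(%copy : vector<[4]x[4]xf32>)
/// ^bb1(%blockArg: vector<[4]x[4]xf32>):
/// ```
///
/// the block argument `%blockArg`, the copy `%copy`, and (via the copy's tile
/// operand) `%tile` can all be placed in a single live range when their
/// individual ranges do not overlap, so they share one tile ID and no tile
/// moves are needed.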
static SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  DenseMap<Value, LiveRange *> liveRanges;
  for (auto &[value, liveRange] : initialLiveRanges) {
    liveRanges.insert({value, &liveRange});
  }

  // Merge the live ranges of values `a` and `b` into one (if they do not
  // overlap). After this, the values `a` and `b` will both point to the same
  // live range (which will contain multiple values).
  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
    LiveRange *aLiveRange = liveRanges.at(a);
    LiveRange *bLiveRange = liveRanges.at(b);
    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
      aLiveRange->unionWith(*bLiveRange);
      for (Value value : bLiveRange->values)
        liveRanges[value] = aLiveRange;
    }
  };

  // Merge the live ranges of new definitions with their tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!armSMEOp)
      return;
    for (auto operand : armSMEOp->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        mergeValuesIfNonOverlapping(value, operand);
    }
  };

  // Merge the live ranges of block arguments with their predecessors.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
    });
  };

  auto applyRule = [&](auto rule) {
    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
  };

  // Unify as many live ranges as we can. This prevents unnecessary moves.
  applyRule(unifyBlockArgumentsWithPredecessors);
  applyRule(unifyDefinitionsWithOperands);

  // Remove duplicate live range entries.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto [_, liveRange] : liveRanges) {
    if (!liveRange->empty())
      uniqueLiveRanges.insert(liveRange);
  }

  // Sort the new live ranges by starting point (ready for tile allocation).
  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  llvm::sort(coalescedLiveRanges,
             [](LiveRange *a, LiveRange *b) { return *a < *b; });
  return std::move(coalescedLiveRanges);
}

/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
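///
/// For example, a live range whose only value is produced by a trivially
/// cloneable op such as `arm_sme.zero` is preferred for spilling, since
/// "spilling" it only requires re-materializing the op rather than saving and
/// restoring a tile.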
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
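  //
  // For example (illustrative), a live range with intervals [2, 5) and [9, 12)
  // is inactive over [5, 9); another range that fits entirely within [5, 9)
  // can safely share its tile ID.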
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}

/// Assigns a tile ID to an MLIR value.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
  for (Operation *user : value.getUsers()) {
    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
      // Ensure ArmSME ops that don't produce a value still get a tile ID.
      if (!hasTileResult(tileOp))
        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
    }
  }
}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    /// Eliminates copies where the operand has the same tile ID.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    /// Validates each predecessor to a tile block argument has been assigned
    /// the same tile ID.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg) {
        // Not a block argument (nothing to validate).
        return success();
      }
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (tileMismatch)
          return;
        if (!isAllocatedToSameTile(predecessorTile)) {
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtual tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(/*isSuccess=*/!tileMismatch);
    };

    /// Attempts to resolve (trivial) tile ID conflicts.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // Operand already allocated to the correct tile.
        // No conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtual tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning prevents a move/spill (though may require recomputation).
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(clonedOp,
                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to the value.
      assignTileIdToValue(rewriter, value, tileIdAttr);

      // 2. Attempt to eliminate redundant tile copies.
      if (succeeded(foldRedundantCopies(value)))
        continue;

      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();

      // 4. Attempt to resolve (trivial) tile ID conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}

/// Prints live ranges alongside operation names for debugging.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}

struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
} // namespace

LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::map_to_vector(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; });
    llvm::sort(initialRanges,
               [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no users).
  // This prevents the LLVM conversion from needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}