MLIR 23.0.0git
TileAllocation.cpp
Go to the documentation of this file.
1//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transform allocates SME tiles at the 'func.func' op level for ArmSME
10// operations. It roughly implements a linear scan register allocator, similar
11// to the one outlined in [1], but with simplifications and assumptions made for
12// our use case. Note that this is a greedy allocator (so it may not always find
13// the most optimal allocation of tiles).
14//
15// The allocator operates at the CF dialect level. It is the responsibility of
16// users to ensure the IR has been lowered to CF before invoking the tile
17// allocator.
18//
19// The 128-bit tiles overlap with other element tiles as follows (see section
20// B2.3.2 of SME spec [2]):
21//
22// Tile Overlaps
23// ---------------------------------------------------------------------------
24// ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
25// ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
26// ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
27// ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
28// ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
29// ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
30// ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
31// ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
32// ZA0.D ZA0.Q, ZA8.Q
33// ZA1.D ZA1.Q, ZA9.Q
34// ZA2.D ZA2.Q, ZA10.Q
35// ZA3.D ZA3.Q, ZA11.Q
36// ZA4.D ZA4.Q, ZA12.Q
37// ZA5.D ZA5.Q, ZA13.Q
38// ZA6.D ZA6.Q, ZA14.Q
39// ZA7.D ZA7.Q, ZA15.Q
40//
41// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44// [2] https://developer.arm.com/documentation/ddi0616/aa
45//
46//===----------------------------------------------------------------------===//
47
55#include "llvm/ADT/IntervalMap.h"
56#include "llvm/ADT/SmallVectorExtras.h"
57#include "llvm/ADT/TypeSwitch.h"
58
59namespace mlir::arm_sme {
60#define GEN_PASS_DEF_TESTTILEALLOCATION
61#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
62} // namespace mlir::arm_sme
63
64using namespace mlir;
65using namespace mlir::arm_sme;
66
67namespace {
68
/// Bitmask over the sixteen 128-bit (ZA*.Q) tiles occupied by an SME virtual
/// tile. Wider-element tiles (B/H/S/D) set one bit for every 128-bit tile they
/// overlap with (per the overlap table in the file header), so two tiles alias
/// iff their masks intersect.
enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  // NOTE(review): the bitwise operators used on TileMask elsewhere in this
  // file (&, |=, ^=) need LLVM's bitmask-enum support; an
  // LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/kZA0B) declaration appears to
  // have been lost from this spot in the extraction -- confirm against the
  // original source.
};
112
113/// Returns the set of masks relevant for the given type.
114static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
115 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
116 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
117 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
118 TileMask::kZA2S, TileMask::kZA3S};
119 static constexpr std::array ZA_D_MASKS = {
120 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
121 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
122 static constexpr std::array ZA_Q_MASKS = {
123 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
124 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
125 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
126 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
127 switch (type) {
128 case ArmSMETileType::ZAB:
129 return ZA_B_MASKS;
130 case ArmSMETileType::ZAH:
131 return ZA_H_MASKS;
132 case ArmSMETileType::ZAS:
133 return ZA_S_MASKS;
134 case ArmSMETileType::ZAD:
135 return ZA_D_MASKS;
136 case ArmSMETileType::ZAQ:
137 return ZA_Q_MASKS;
138 }
139 llvm_unreachable("unknown type in getMasks");
140}
141
142class TileAllocator {
143public:
144 /// Allocates and returns a tile ID. Fails if there are no tiles left.
145 FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
146 auto masks = getMasks(tileType);
147 for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
148 if ((tilesInUse & tileMask) == TileMask::kNone) {
149 tilesInUse |= tileMask;
150 return tileId;
151 }
152 }
153 return failure();
154 }
155
156 /// Acquires a specific tile ID. Asserts the tile is initially free.
157 void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
158 TileMask tileMask = getMasks(tileType)[tileId];
159 assert((tilesInUse & tileMask) == TileMask::kNone &&
160 "cannot acquire allocated tile!");
161 tilesInUse |= tileMask;
162 }
163
164 /// Releases a previously allocated tile ID.
165 void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
166 TileMask tileMask = getMasks(tileType)[tileId];
167 assert((tilesInUse & tileMask) == tileMask &&
168 "cannot release unallocated tile!");
169 tilesInUse ^= tileMask;
170 }
171
172 /// Allocates an in-memory tile ID.
173 unsigned allocateInMemoryTileId() {
174 // Note: We never release in-memory tile IDs. We could, which may allow
175 // reusing an allocation, but as we _never_ want to spill an SME tile this
176 // is not optimized.
177 return nextInMemoryTileId++;
178 }
179
180private:
181 TileMask tilesInUse = TileMask::kNone;
182 unsigned nextInMemoryTileId = kInMemoryTileIdBase;
183};
184
185/// Add new intermediate blocks for the true and false destinations of
186/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
187/// overlaps due to copies at branches.
188///
189/// BEFORE:
190/// ```mlir
191/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
192/// ```
193///
194/// AFTER:
195/// ```mlir
196/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
197/// ^bb1_copy:
198/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
199/// ^bb2_copy:
200/// cf.br ^bb2
201/// ```
202void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
204 function.walk([&](cf::CondBranchOp condBranch) {
205 if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
206 return isValidSMETileVectorType(value.getType());
207 })) {
208 worklist.push_back(condBranch);
209 }
210 });
211
212 auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
213 rewriter.setInsertionPointToEnd(source);
214 cf::BranchOp::create(rewriter, loc, dest, args);
215 };
216
217 for (auto condBranch : worklist) {
218 auto loc = condBranch.getLoc();
219 Block *block = condBranch->getBlock();
220 auto *newTrueBranch = rewriter.splitBlock(block, block->end());
221 auto *newFalseBranch = rewriter.splitBlock(block, block->end());
222 insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
223 condBranch.getTrueDestOperands());
224 insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
225 condBranch.getFalseDestOperands());
226 rewriter.modifyOpInPlace(condBranch, [&] {
227 condBranch.getFalseDestOperandsMutable().clear();
228 condBranch.getTrueDestOperandsMutable().clear();
229 condBranch.setSuccessor(newTrueBranch, 0);
230 condBranch.setSuccessor(newFalseBranch, 1);
231 });
232 }
233}
234
235/// Inserts tile copies at `cf.br` operations.
236///
237/// BEFORE:
238/// ```mlir
239/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
240/// ```
241///
242/// AFTER:
243/// ```mlir
244/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
245/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
246/// ```
247void insertCopiesAtBranches(IRRewriter &rewriter,
248 FunctionOpInterface function) {
249 for (Block &block : function.getBlocks()) {
250 Operation *terminator = block.getTerminator();
251 if (!isa<cf::BranchOp>(terminator))
252 continue;
253 rewriter.setInsertionPoint(terminator);
254 for (OpOperand &operand : terminator->getOpOperands()) {
255 if (isValidSMETileVectorType(operand.get().getType())) {
256 auto copy =
257 CopyTileOp::create(rewriter, terminator->getLoc(), operand.get());
258 rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
259 }
260 }
261 }
262}
263
/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure when moving out of SSA the semantics of the program are
/// preserved.
///
/// Note: The order matters — copies are inserted at the `cf.br`s created by
/// the splitting step.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
277
/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  // NOTE(review): `ranges` is held via unique_ptr, presumably so LiveRange
  // itself remains movable/storable in DenseMap — confirm IntervalMap's move
  // semantics before changing this.
  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    // Copy every interval over; IntervalMap coalesces adjacent intervals with
    // equal values (insertion of an overlapping interval would assert).
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    // An empty interval ([x, x)) would be invalid for a half-open map, so the
    // value is tracked without extending the range.
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  // Orders live ranges by their start point (used to sort for allocation).
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  /// Returns the tile type of this live range. Requires a non-empty `values`
  /// set; all values in a range share the same tile type (they were merged
  /// only via copies/block arguments of the same vector type).
  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
341
/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This numbering is only correct
/// if all ArmSME ops have been lowered to flat (CF) control flow, which is
/// checked here (an error is emitted otherwise).
FailureOr<DenseMap<Operation *, unsigned>>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
      // Reject ArmSME ops nested inside this op's regions: liveness across
      // nested regions is not modelled by this numbering.
      WalkResult walkResult =
          op.walk([&](ArmSMETileOpInterface nestedOp) -> WalkResult {
            if (&op != nestedOp.getOperation())
              return WalkResult::interrupt();
            return WalkResult::advance();
          });
      if (walkResult.wasInterrupted()) {
        return op.emitError("ArmSME tile allocation requires flattened control "
                            "flow; run -convert-scf-to-cf before this pass "
                            "(e.g. via convert-arm-sme-to-llvm pipeline)");
      }
      operationToIndexMap.try_emplace(&op, index++);
    }
  }

  return operationToIndexMap;
}
372
373/// Gather live ranges for SME tiles from the MLIR liveness analysis.
375gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
376 LiveRange::Allocator &liveRangeAllocator,
377 Liveness &liveness, FunctionOpInterface function) {
378 assert(!operationToIndexMap.empty() && "expected operation numbering");
380 /// Defines or updates a live range for an SME tile value. Live-ins may update
381 /// an existing live range (rather than define a new one). Note: If
382 /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
383 /// the block.
384 auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
385 LivenessBlockInfo const &livenessInfo,
386 bool liveAtBlockEntry = false) {
387 if (!isValidSMETileVectorType(value.getType()))
388 return;
389 // Find or create a live range for `value`.
390 auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
391 LiveRange &valueLiveRange = it->second;
392 auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
393 // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
394 unsigned startOpIdx =
395 operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
396 unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
397 valueLiveRange.insert(value, startOpIdx, endOpIdx);
398 };
399
400 for (Block &block : function.getBlocks()) {
401 LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
402 // Handle block arguments:
403 for (Value argument : block.getArguments())
404 defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
405 /*liveAtBlockEntry=*/true);
406 // Handle live-ins:
407 for (Value liveIn : livenessInfo->in())
408 defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
409 /*liveAtBlockEntry=*/true);
410 // Handle new definitions:
411 for (Operation &op : block) {
412 for (Value result : op.getResults())
413 defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
414 }
415 }
416
417 return liveRanges;
418}
419
420/// Iterate over all predecessor tile values to a (tile) block argument.
421static void forEachPredecessorTileValue(BlockArgument blockArg,
422 function_ref<void(Value)> callback) {
423 Block *block = blockArg.getOwner();
424 unsigned argNumber = blockArg.getArgNumber();
425 for (Block *pred : block->getPredecessors()) {
426 TypeSwitch<Operation *>(pred->getTerminator())
427 .Case([&](cf::BranchOp branch) {
428 Value predecessorOperand = branch.getDestOperands()[argNumber];
429 callback(predecessorOperand);
430 })
431 .Case([&](cf::CondBranchOp condBranch) {
432 if (condBranch.getFalseDest() == block) {
433 Value predecessorOperand =
434 condBranch.getFalseDestOperands()[argNumber];
435 callback(predecessorOperand);
436 }
437 if (condBranch.getTrueDest() == block) {
438 Value predecessorOperand =
439 condBranch.getTrueDestOperands()[argNumber];
440 callback(predecessorOperand);
441 }
442 });
443 }
444}
445
446/// Coalesce live ranges where it would prevent unnecessary tile moves.
448coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
450 for (auto &[value, liveRange] : initialLiveRanges) {
451 liveRanges.insert({value, &liveRange});
452 }
453
454 // Merge the live ranges of values `a` and `b` into one (if they do not
455 // overlap). After this, the values `a` and `b` will both point to the same
456 // live range (which will contain multiple values).
457 auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
458 LiveRange *aLiveRange = liveRanges.at(a);
459 LiveRange *bLiveRange = liveRanges.at(b);
460 if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
461 aLiveRange->unionWith(*bLiveRange);
462 for (Value value : bLiveRange->values)
463 liveRanges[value] = aLiveRange;
464 }
465 };
466
467 // Merge the live ranges of new definitions with their tile operands.
468 auto unifyDefinitionsWithOperands = [&](Value value) {
469 auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
470 if (!armSMEOp)
471 return;
472 for (auto operand : armSMEOp->getOperands()) {
473 if (isValidSMETileVectorType(operand.getType()))
474 mergeValuesIfNonOverlapping(value, operand);
475 }
476 };
477
478 // Merge the live ranges of block arguments with their predecessors.
479 auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
480 auto blockArg = dyn_cast<BlockArgument>(value);
481 if (!blockArg)
482 return;
483 forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
484 mergeValuesIfNonOverlapping(blockArg, predecessorTile);
485 });
486 };
487
488 auto applyRule = [&](auto rule) {
489 llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
490 };
491
492 // Unify as many live ranges as we can. This prevents unnecessary moves.
493 applyRule(unifyBlockArgumentsWithPredecessors);
494 applyRule(unifyDefinitionsWithOperands);
495
496 // Remove duplicate live range entries.
497 SetVector<LiveRange *> uniqueLiveRanges;
498 for (auto [_, liveRange] : liveRanges) {
499 if (!liveRange->empty())
500 uniqueLiveRanges.insert(liveRange);
501 }
502
503 // Sort the new live ranges by starting point (ready for tile allocation).
504 auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
505 llvm::sort(coalescedLiveRanges,
506 [](LiveRange *a, LiveRange *b) { return *a < *b; });
507 return std::move(coalescedLiveRanges);
508}
509
510/// Choose a live range to spill (via some heuristics). This picks either a live
511/// range from `overlappingRanges`, or the new live range `newRange`.
512template <typename OverlappingRangesIterator>
513LiveRange *
514chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
515 LiveRange *newRange) {
516 // Heuristic: Spill trivially copyable operations (usually free).
517 auto isTrivialSpill = [&](LiveRange &allocatedRange) {
518 return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
519 newRange->getTileType()) &&
520 allocatedRange.values.size() == 1 &&
522 allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
523 };
524 if (isTrivialSpill(*newRange))
525 return newRange;
526 auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
527 if (trivialSpill != overlappingRanges.end())
528 return &*trivialSpill;
529
530 // Heuristic: Spill the range that ends last (with a compatible tile type).
531 auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
532 return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
533 a.end() < b.end();
534 };
535 LiveRange &latestEndingLiveRange =
536 *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
537 if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
538 return &latestEndingLiveRange;
539 return newRange;
540}
541
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
/// Walks the ranges in start order (a linear-scan style allocation); ranges
/// that cannot get a hardware tile receive an in-memory tile ID
/// (>= kInMemoryTileIdBase).
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      // (Their tile ID is released so it can be reused within the hole, and
      // re-acquired if/when the range becomes active again.)
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      // The spilled range (which may be `nextRange` itself) lives in memory.
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    // (Skip in-memory tile IDs: they do not occupy a hardware tile.)
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
645
646/// Assigns a tile ID to an MLIR value.
647void assignTileIdToValue(IRRewriter &rewriter, Value value,
648 IntegerAttr tileIdAttr) {
649 if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
650 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
651 for (Operation *user : value.getUsers()) {
652 if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
653 // Ensure ArmSME ops that don't produce a value still get a tile ID.
654 if (!hasTileResult(tileOp))
655 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
656 }
657 }
658}
659
660/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
661LogicalResult assignTileIdsAndResolveTrivialConflicts(
662 IRRewriter &rewriter, FunctionOpInterface function,
663 ArrayRef<LiveRange *> allocatedLiveRanges) {
664 for (LiveRange const *liveRange : allocatedLiveRanges) {
665 auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
666 auto isAllocatedToSameTile = [&](Value value) {
667 if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
668 tileOp && tileOp.getTileId() == tileIdAttr)
669 return true;
670 return liveRange->values.contains(value);
671 };
672
673 /// Eliminates copies where the operand has the same tile ID.
674 auto foldRedundantCopies = [&](Value value) -> LogicalResult {
675 auto copyOp = value.getDefiningOp<CopyTileOp>();
676 if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
677 return failure();
678 rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
679 return success();
680 };
681
682 /// Validates each predecessor to a tile block argument has been assigned
683 /// the same tile ID.
684 auto validateBlockArguments = [&](Value value) {
685 auto blockArg = dyn_cast<BlockArgument>(value);
686 if (!blockArg) {
687 // Not a block argument (nothing to validate).
688 return success();
689 }
690 bool tileMismatch = false;
691 forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
692 if (tileMismatch)
693 return;
694 if (!isAllocatedToSameTile(predecessorTile)) {
695 blockArg.getOwner()->getParentOp()->emitOpError(
696 "block argument not allocated to the same SME virtial tile as "
697 "predecessors");
698 tileMismatch = true;
699 }
700 });
701 return success(/*isSuccess=*/!tileMismatch);
702 };
703
704 /// Attempts to resolve (trivial) tile ID conflicts.
705 auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
706 auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
707 OpOperand *tileOperand = getTileOpOperand(tileOp);
708 if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
709 // Operand already allocated to the correct tile.
710 // No conflict to resolve.
711 return success();
712 }
713 auto operandTileOp =
714 tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
715 if (!isTriviallyCloneableTileOp(operandTileOp)) {
716 auto error =
717 tileOp.emitOpError("tile operand allocated to different SME "
718 "virtial tile (move required)");
719 error.attachNote(tileOperand->get().getLoc())
720 << "tile operand is: " << tileOperand->get();
721 return error;
722 }
723 // Cloning prevents a move/spill (though may require recomputation).
724 rewriter.setInsertionPoint(tileOp);
725 auto clonedOp = operandTileOp.clone();
726 rewriter.modifyOpInPlace(clonedOp,
727 [&] { clonedOp.setTileId(tileOp.getTileId()); });
728 rewriter.insert(clonedOp);
729 if (isa<CopyTileOp>(tileOp)) {
730 rewriter.replaceAllUsesWith(tileOp->getResult(0),
731 clonedOp->getResult(0));
732 } else {
733 rewriter.modifyOpInPlace(
734 tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
735 }
736 return success();
737 };
738
739 for (Value value : liveRange->values) {
740 // 1. Assign the tile ID to the value.
741 assignTileIdToValue(rewriter, value, tileIdAttr);
742
743 // 2. Attempt to eliminate redundant tile copies.
744 if (succeeded(foldRedundantCopies(value)))
745 continue;
746
747 // 3. Validate tile block arguments.
748 if (failed(validateBlockArguments(value)))
749 return failure();
750
751 // 4. Attempt to resolve (trivial) tile ID conflicts.
752 if (failed(resolveTrivialTileConflicts(value)))
753 return failure();
754 }
755 }
756 return success();
757}
758
759/// Prints live ranges alongside operation names for debugging.
760void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
762 FunctionOpInterface function) {
763 llvm::errs() << "SME Tile Liveness: @" << function.getName()
764 << "\nKey:\nS - Start\nE - End\n| - Live\n";
765 for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
766 llvm::errs() << "^bb" << blockIdx << ":\n";
767 for (Operation &op : block.getOperations()) {
768 unsigned operationIndex = operationToIndexMap.at(&op);
769 for (LiveRange const *range : liveRanges) {
770 char liveness = ' ';
771 for (auto it = range->ranges->begin(); it != range->ranges->end();
772 ++it) {
773 if (it.start() == operationIndex)
774 liveness = (liveness == 'E' ? '|' : 'S');
775 else if (it.stop() == operationIndex)
776 liveness = (liveness == 'S' ? '|' : 'E');
777 else if (operationIndex >= it.start() && operationIndex < it.stop())
778 liveness = '|';
779 }
780 llvm::errs() << liveness;
781 }
782 llvm::errs() << ' ' << op.getName() << '\n';
783 }
784 }
785 llvm::errs() << "==========\n";
786}
787
788struct TestTileAllocationPass
789 : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
790 using TestTileAllocationBase::TestTileAllocationBase;
791 void runOnOperation() override {
792 FunctionOpInterface function = getOperation();
793 if (preprocessOnly) {
794 IRRewriter rewriter(function);
795 return preprocessForTileAllocation(rewriter, function);
796 }
797 if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
798 signalPassFailure();
799 }
800};
801} // namespace
802
/// Allocates SME virtual tiles to the ArmSME operations in `function` and
/// writes the resulting tile IDs back onto the ops. Expects flattened (CF)
/// control flow. If `dumpRanges` is set, the initial and coalesced live
/// ranges are printed to stderr for debugging.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto maybeOperationToIndexMap = generateOperationNumbering(function);
  if (failed(maybeOperationToIndexMap))
    return failure();
  auto &operationToIndexMap = *maybeOperationToIndexMap;
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  // No SME tile values in the function: nothing to allocate.
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::map_to_vector(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; });
    llvm::sort(initialRanges,
               [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
  // users). This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}
return success()
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents liveness information on block level.
Definition Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition Liveness.h:110
Represents an analysis for computing liveness information from a given top-level operation.
Definition Liveness.h:47
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
Definition Liveness.cpp:225
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:425
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:391
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition Utils.cpp:55
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition Utils.cpp:130
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition Utils.cpp:150
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true typeA is >= (in terms of bytes) than typeB.
Definition Utils.cpp:178
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition Utils.cpp:163
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition Utils.cpp:155
bool operator<(const Fraction &x, const Fraction &y)
Definition Fraction.h:83
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144