55#include "llvm/ADT/IntervalMap.h"
56#include "llvm/ADT/SmallVectorExtras.h"
57#include "llvm/ADT/TypeSwitch.h"
60#define GEN_PASS_DEF_TESTTILEALLOCATION
61#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
// NOTE(review): fragmentary extraction — interior source lines are missing
// and the stray leading numbers are extraction residue. Comments are hedged.
// Bitmask enum tracking which ZA (SME accumulator) tiles are in use; each
// enumerator presumably sets the bits of the sub-tiles it aliases.
69enum class TileMask :
unsigned {
// Mask tables per tile type: 1 byte tile, 2 half tiles, 4 single tiles,
// 8 double tiles, 16 quad tiles (counts match the arrays below).
 115 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
 116 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
 117 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
 118 TileMask::kZA2S, TileMask::kZA3S};
 119 static constexpr std::array ZA_D_MASKS = {
 120 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
 121 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
 122 static constexpr std::array ZA_Q_MASKS = {
 123 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
 124 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
 125 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
 126 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
// getMasks(tileType): switch dispatching to the table for each tile type.
// The returns between the case labels are missing from this extraction.
 128 case ArmSMETileType::ZAB:
 130 case ArmSMETileType::ZAH:
 132 case ArmSMETileType::ZAS:
 134 case ArmSMETileType::ZAD:
 136 case ArmSMETileType::ZAQ:
 139 llvm_unreachable(
"unknown type in getMasks");
// Finds the first tile of `tileType` whose mask does not intersect
// `tilesInUse`, marks it used, and (per the enumerate) returns its index.
// The failure path (no free tile) is not visible here.
 145 FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
 146 auto masks = getMasks(tileType);
 147 for (
auto [tileId, tileMask] : llvm::enumerate(masks)) {
 148 if ((tilesInUse & tileMask) == TileMask::kNone) {
 149 tilesInUse |= tileMask;
// Marks a specific tile id as in use; asserts it was previously free.
 157 void acquireTileId(ArmSMETileType tileType,
unsigned tileId) {
 158 TileMask tileMask = getMasks(tileType)[tileId];
 159 assert((tilesInUse & tileMask) == TileMask::kNone &&
 160 "cannot acquire allocated tile!");
 161 tilesInUse |= tileMask;
// Releases a tile id; asserts every bit of its mask was set. The XOR is
// equivalent to clearing the mask given that assertion holds.
 165 void releaseTileId(ArmSMETileType tileType,
unsigned tileId) {
 166 TileMask tileMask = getMasks(tileType)[tileId];
 167 assert((tilesInUse & tileMask) == tileMask &&
 168 "cannot release unallocated tile!");
 169 tilesInUse ^= tileMask;
// Hands out monotonically increasing ids for spilled (in-memory) tiles;
// presumably offset by kInMemoryTileIdBase — the declaration is not visible.
 173 unsigned allocateInMemoryTileId() {
 177 return nextInMemoryTileId++;
 181 TileMask tilesInUse = TileMask::kNone;
// Splits every cf.cond_br that forwards SME tile operands: each such
// cond_br's successors are rerouted through two new blocks containing plain
// cf.br ops that carry the destination operands, and the cond_br itself is
// left with empty operand lists. NOTE(review): interior lines (worklist
// declaration, the insertJump lambda header, walk-result plumbing) are
// missing from this extraction.
202void splitCondBranches(
IRRewriter &rewriter, FunctionOpInterface function) {
// Collect cond branches with at least one valid SME tile vector operand.
 204 function.walk([&](cf::CondBranchOp condBranch) {
 205 if (llvm::any_of(condBranch->getOperands(), [&](
Value value) {
 206 return isValidSMETileVectorType(value.getType());
 208 worklist.push_back(condBranch);
 214 cf::BranchOp::create(rewriter, loc, dest, args);
 217 for (
auto condBranch : worklist) {
 218 auto loc = condBranch.getLoc();
 219 Block *block = condBranch->getBlock();
// Two empty blocks split off the end of the cond_br's block; they become
// the new true/false successors and receive the forwarding branches.
 220 auto *newTrueBranch = rewriter.
splitBlock(block, block->
end());
 221 auto *newFalseBranch = rewriter.
splitBlock(block, block->
end());
 222 insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
 223 condBranch.getTrueDestOperands());
 224 insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
 225 condBranch.getFalseDestOperands());
// The cond_br no longer forwards operands itself; successors 0/1 are the
// freshly created jump blocks.
 227 condBranch.getFalseDestOperandsMutable().clear();
 228 condBranch.getTrueDestOperandsMutable().clear();
 229 condBranch.setSuccessor(newTrueBranch, 0);
 230 condBranch.setSuccessor(newFalseBranch, 1);
// For each block ending in a plain cf.br, inserts a CopyTileOp before the
// terminator for tile operands (the operand-type filtering and the
// replacement of the branch operand with the copy result are in lines
// missing from this extraction).
247void insertCopiesAtBranches(
IRRewriter &rewriter,
 248 FunctionOpInterface function) {
 249 for (
Block &block : function.getBlocks()) {
 251 if (!isa<cf::BranchOp>(terminator))
 257 CopyTileOp::create(rewriter, terminator->
getLoc(), operand.get());
// Preprocessing pipeline run before live-range analysis: first split
// cond branches carrying tile operands, then insert tile copies at plain
// branches, so every control-flow edge forwarding a tile is a simple cf.br.
272void preprocessForTileAllocation(
IRRewriter &rewriter,
 273 FunctionOpInterface function) {
 274 splitCondBranches(rewriter, function);
 275 insertCopiesAtBranches(rewriter, function);
// LiveRange: a set of half-open [start, stop) operation-index intervals,
// plus the SSA values covered and (once allocated) a tile id.
// NOTE(review): struct header and several members are missing from this
// extraction; comments below describe only the visible lines.
 284 using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
 285 llvm::IntervalMapHalfOpenInfo<unsigned>>;
 286 using Allocator = RangeSet::Allocator;
// Sentinel stored in every interval; lookup() returning this means "live".
 288 static constexpr uint8_t kValidLiveRange = 0xff;
 290 LiveRange(Allocator &allocator)
 291 : ranges(std::make_unique<RangeSet>(allocator)) {}
// True if any interval of this range intersects `otherRange` (uses
// IntervalMapOverlaps; the rest of the expression is not visible here).
 294 bool overlaps(LiveRange
const &otherRange)
const {
 295 return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
// True if the single operation index `point` falls inside this range.
 301 bool overlaps(uint64_t point)
const {
 302 return ranges->lookup(point) == kValidLiveRange;
// Merges another range's intervals and value set into this one.
 306 void unionWith(LiveRange
const &otherRange) {
 307 for (
auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
 309 ranges->insert(it.start(), it.stop(), kValidLiveRange);
 310 values.set_union(otherRange.values);
// Records that `value` is live over [start, end) (half-open info above).
 314 void insert(Value value,
unsigned start,
unsigned end) {
 315 values.insert(value);
 317 ranges->insert(start, end, kValidLiveRange);
 320 bool empty()
const {
return ranges->empty(); }
 321 unsigned start()
const {
return ranges->start(); }
 322 unsigned end()
const {
return ranges->stop(); }
// Orders ranges by their first live point (used to sort before scanning).
 323 bool operator<(LiveRange
const &other)
const {
 324 return start() < other.start();
// Tile type of the values in this range; body not visible here.
 327 ArmSMETileType getTileType()
const {
// unique_ptr keeps the IntervalMap stable while LiveRange objects move.
 336 std::unique_ptr<RangeSet> ranges;
// Unset until allocation; may hold an in-memory (spill) id afterwards.
 339 std::optional<unsigned> tileId;
// Assigns a dense increasing index to each operation (blocks presumably
// pre-sorted; the `blocks` definition is not visible here). Fails with a
// diagnostic if an ArmSME tile op is found nested inside a region, i.e.
// control flow was not flattened to cf before this pass.
346FailureOr<DenseMap<Operation *, unsigned>>
 347generateOperationNumbering(FunctionOpInterface function) {
 352 for (
Block *block : blocks) {
// Walk nested ops: any ArmSME op other than `op` itself means there is
// region-based control flow, which tile allocation cannot handle.
 356 op.walk([&](ArmSMETileOpInterface nestedOp) ->
WalkResult {
 357 if (&op != nestedOp.getOperation())
 362 return op.emitError(
"ArmSME tile allocation requires flattened control "
 363 "flow; run -convert-scf-to-cf before this pass "
 364 "(e.g. via convert-arm-sme-to-llvm pipeline)");
 366 operationToIndexMap.try_emplace(&op,
index++);
 370 return operationToIndexMap;
// gatherTileLiveRanges (function header line missing from this extraction):
// builds a LiveRange per SME tile value using MLIR's Liveness analysis and
// the operation numbering computed above.
 376 LiveRange::Allocator &liveRangeAllocator,
 377 Liveness &liveness, FunctionOpInterface function) {
 378 assert(!operationToIndexMap.empty() &&
"expected operation numbering");
// Creates (or extends) the live range of `value`, spanning from
// `firstUseOrDef` to its last use in the block. `liveAtBlockEntry` shifts
// the start one index earlier so block-live-in values cover the boundary.
 384 auto defineOrUpdateValueLiveRange = [&](
Value value,
Operation *firstUseOrDef,
 386 bool liveAtBlockEntry =
false) {
 390 auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
 391 LiveRange &valueLiveRange = it->second;
 392 auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
 394 unsigned startOpIdx =
 395 operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
 396 unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
 397 valueLiveRange.insert(value, startOpIdx, endOpIdx);
// Per block: handle block arguments, values live-in from predecessors,
// then op results (type filtering occurs on lines not visible here).
 400 for (
Block &block : function.getBlocks()) {
 404 defineOrUpdateValueLiveRange(argument, &block.
front(), *livenessInfo,
 407 for (
Value liveIn : livenessInfo->
in())
 408 defineOrUpdateValueLiveRange(liveIn, &block.
front(), *livenessInfo,
 413 defineOrUpdateValueLiveRange(
result, &op, *livenessInfo);
// Invokes `callback` with the value each predecessor terminator forwards
// into `blockArg`'s position. Only cf.br / cf.cond_br are handled (the
// TypeSwitch header, `block`/`argNumber` setup, and predecessor loop are
// on lines missing from this extraction). A cond_br may feed the block on
// both edges, hence the two independent `if`s.
421static void forEachPredecessorTileValue(
BlockArgument blockArg,
 427 .Case([&](cf::BranchOp branch) {
 428 Value predecessorOperand = branch.getDestOperands()[argNumber];
 429 callback(predecessorOperand);
 431 .Case([&](cf::CondBranchOp condBranch) {
 432 if (condBranch.getFalseDest() == block) {
 433 Value predecessorOperand =
 434 condBranch.getFalseDestOperands()[argNumber];
 435 callback(predecessorOperand);
 437 if (condBranch.getTrueDest() == block) {
 438 Value predecessorOperand =
 439 condBranch.getTrueDestOperands()[argNumber];
 440 callback(predecessorOperand);
// coalesceTileLiveRanges fragments: merges live ranges that must share a
// tile (op results with their tile operands; block arguments with their
// predecessor operands), then returns the unique non-empty ranges sorted
// by start point. The function header and some declarations are missing
// from this extraction.
// Map each value to a (mutable, shared) pointer at its initial range.
 450 for (
auto &[value, liveRange] : initialLiveRanges) {
 451 liveRanges.insert({value, &liveRange});
// Union b's range into a's when they don't overlap, repointing all of b's
// values at the merged range so future lookups stay consistent.
 457 auto mergeValuesIfNonOverlapping = [&](
Value a,
Value b) {
 458 LiveRange *aLiveRange = liveRanges.at(a);
 459 LiveRange *bLiveRange = liveRanges.at(
b);
 460 if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
 461 aLiveRange->unionWith(*bLiveRange);
 462 for (
Value value : bLiveRange->values)
 463 liveRanges[value] = aLiveRange;
// Rule 1: an ArmSME op result is merged with each of its operands.
 468 auto unifyDefinitionsWithOperands = [&](
Value value) {
 469 auto armSMEOp = value.
getDefiningOp<ArmSMETileOpInterface>();
 472 for (
auto operand : armSMEOp->getOperands()) {
 474 mergeValuesIfNonOverlapping(value, operand);
// Rule 2: a block argument is merged with every predecessor's forwarded
// tile value.
 479 auto unifyBlockArgumentsWithPredecessors = [&](
Value value) {
 480 auto blockArg = dyn_cast<BlockArgument>(value);
 483 forEachPredecessorTileValue(blockArg, [&](
Value predecessorTile) {
 484 mergeValuesIfNonOverlapping(blockArg, predecessorTile);
 488 auto applyRule = [&](
auto rule) {
 489 llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
// Block arguments first, then definitions (ordering appears deliberate;
// the rationale is on lines not visible here).
 493 applyRule(unifyBlockArgumentsWithPredecessors);
 494 applyRule(unifyDefinitionsWithOperands);
// Deduplicate via set-vector, drop empties, sort by range start.
 498 for (
auto [_, liveRange] : liveRanges) {
 499 if (!liveRange->empty())
 500 uniqueLiveRanges.insert(liveRange);
 504 auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
 505 llvm::sort(coalescedLiveRanges,
 506 [](LiveRange *a, LiveRange *
b) {
return *a < *
b; });
 507 return std::move(coalescedLiveRanges);
// Picks which live range to spill when `newRange` cannot get a tile.
// Heuristic 1: prefer a "trivial spill" — a single-value range whose
// defining op is cheap to rematerialize (the isTriviallyCloneableTileOp
// call is partially visible) and whose tile type is large enough.
// Heuristic 2: otherwise spill the overlapping range with the smaller tile
// type or the latest end point. Some lines are missing from this extraction.
512template <
typename OverlappingRangesIterator>
514chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
 515 LiveRange *newRange) {
 517 auto isTrivialSpill = [&](LiveRange &allocatedRange) {
 519 newRange->getTileType()) &&
 520 allocatedRange.values.size() == 1 &&
 522 allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
// If the new range itself is trivial, spill it (return is on a missing line).
 524 if (isTrivialSpill(*newRange))
 526 auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
 527 if (trivialSpill != overlappingRanges.end())
 528 return &*trivialSpill;
 531 auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &
b) {
// max_element under this ordering yields the "most spillable" candidate;
// fall back to spilling newRange if it ranks even lower (missing tail).
 535 LiveRange &latestEndingLiveRange =
 536 *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
 537 if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
 538 return &latestEndingLiveRange;
// Linear-scan style allocation over the coalesced live ranges, extended
// with an inactive set for ranges with holes (a range can be live at some
// points and not others). NOTE(review): several interior lines (set
// declarations, remove_if return values, the spill diagnostics) are
// missing from this extraction.
543void allocateTilesToLiveRanges(
 545 TileAllocator tileAllocator;
 556 for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
 557 auto currentPoint = nextRange->start();
// Expire finished active ranges; demote ranges not live at this point to
// the inactive set (releasing their tile in both cases).
 559 activeRanges.remove_if([&](LiveRange *activeRange) {
 561 if (activeRange->end() <= currentPoint) {
 562 tileAllocator.releaseTileId(activeRange->getTileType(),
 563 *activeRange->tileId);
 567 if (!activeRange->overlaps(currentPoint)) {
 568 tileAllocator.releaseTileId(activeRange->getTileType(),
 569 *activeRange->tileId);
 570 inactiveRanges.insert(activeRange);
// Expire finished inactive ranges; promote ranges live again at this
// point back to active (re-acquiring their previously assigned tile).
 576 inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
 578 if (inactiveRange->end() <= currentPoint) {
 582 if (inactiveRange->overlaps(currentPoint)) {
 583 tileAllocator.acquireTileId(inactiveRange->getTileType(),
 584 *inactiveRange->tileId);
 585 activeRanges.insert(inactiveRange);
// Inactive ranges overlapping nextRange anywhere must still be treated as
// conflicts: temporarily re-acquire their tiles so allocation avoids them.
 595 for (LiveRange *inactiveRange : inactiveRanges) {
 596 if (inactiveRange->overlaps(*nextRange)) {
 599 tileAllocator.acquireTileId(inactiveRange->getTileType(),
 600 *inactiveRange->tileId);
 601 overlappingInactiveRanges.push_back(inactiveRange);
 606 auto rangeTileType = nextRange->getTileType();
 607 auto tileId = tileAllocator.allocateTileId(rangeTileType);
 608 if (succeeded(tileId)) {
 609 nextRange->tileId = *tileId;
// No free tile: pick a victim among all conflicting (active + overlapping
// inactive) ranges, free its tile, give nextRange a real tile, and assign
// the victim an in-memory (spill) id.
 612 auto allOverlappingRanges = llvm::concat<LiveRange>(
 613 llvm::make_pointee_range(activeRanges.getArrayRef()),
 614 llvm::make_pointee_range(overlappingInactiveRanges));
 616 LiveRange *rangeToSpill =
 617 chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
 618 if (rangeToSpill != nextRange) {
 620 tileAllocator.releaseTileId(rangeToSpill->getTileType(),
 621 *rangeToSpill->tileId);
// Re-allocation here cannot fail: a tile was just released.
 623 nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
 625 if (!activeRanges.remove(rangeToSpill)) {
 626 bool removed = inactiveRanges.remove(rangeToSpill);
 627 assert(removed &&
"expected a range to be removed!");
 631 rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
 636 activeRanges.insert(nextRange);
// Undo the temporary acquisition of overlapping-inactive tiles.
 639 for (LiveRange *range : overlappingInactiveRanges) {
 641 tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
// assignTileIdToValue fragment (signature start missing from this
// extraction): stamps `tileIdAttr` onto the value's defining ArmSME op
// (if any) and onto each ArmSME user (the user loop header is missing).
 648 IntegerAttr tileIdAttr) {
 649 if (
auto tileOp = value.
getDefiningOp<ArmSMETileOpInterface>())
 650 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
 652 if (
auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
 655 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
// For every allocated live range: stamp the tile id on all values/ops in
// the range, fold copies that became redundant, validate block arguments,
// and resolve trivial conflicts by cloning cheap defining ops with the
// destination tile id. Interior lines are missing from this extraction.
// NOTE(review): "virtial" in the diagnostics below looks like a typo for
// "virtual" — but it is a runtime error message that lit tests may match
// verbatim, so confirm against the test suite before correcting it.
661LogicalResult assignTileIdsAndResolveTrivialConflicts(
 662 IRRewriter &rewriter, FunctionOpInterface function,
 664 for (LiveRange
const *liveRange : allocatedLiveRanges) {
// A value is "on the same tile" if its defining ArmSME op already carries
// this range's tile id, or it belongs to this live range.
 666 auto isAllocatedToSameTile = [&](
Value value) {
 667 if (
auto tileOp = value.
getDefiningOp<ArmSMETileOpInterface>();
 668 tileOp && tileOp.getTileId() == tileIdAttr)
 670 return liveRange->values.contains(value);
// Copies whose source is already on the same tile are redundant; the
// actual replacement happens on lines not visible here.
 674 auto foldRedundantCopies = [&](
Value value) -> LogicalResult {
 676 if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
// Every predecessor must feed a block argument from the same tile;
// otherwise emit an error (message continuation is on missing lines).
 684 auto validateBlockArguments = [&](
Value value) {
 685 auto blockArg = dyn_cast<BlockArgument>(value);
 690 bool tileMismatch =
false;
 691 forEachPredecessorTileValue(blockArg, [&](
Value predecessorTile) {
 694 if (!isAllocatedToSameTile(predecessorTile)) {
 696 "block argument not allocated to the same SME virtial tile as "
// If an op's tile operand lives on a different tile, clone the (trivially
// cloneable) defining op with this tile id instead of emitting a move;
// non-cloneable operands produce the error below.
 705 auto resolveTrivialTileConflicts = [&](
Value value) -> LogicalResult {
 708 if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
 714 tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
 717 tileOp.emitOpError(
"tile operand allocated to different SME "
 718 "virtial tile (move required)");
 719 error.attachNote(tileOperand->get().getLoc())
 720 <<
"tile operand is: " << tileOperand->get();
 725 auto clonedOp = operandTileOp.clone();
 727 [&] { clonedOp.setTileId(tileOp.getTileId()); });
 728 rewriter.
insert(clonedOp);
// A conflicting CopyTileOp becomes fully redundant once the clone carries
// the right tile id; otherwise just retarget the operand at the clone.
 729 if (isa<CopyTileOp>(tileOp)) {
 731 clonedOp->getResult(0));
 734 tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
// Apply the steps above to every value in the range, bailing out on the
// first validation/conflict failure.
 739 for (
Value value : liveRange->values) {
 741 assignTileIdToValue(rewriter, value, tileIdAttr);
 744 if (succeeded(foldRedundantCopies(value)))
 748 if (
failed(validateBlockArguments(value)))
 752 if (
failed(resolveTrivialTileConflicts(value)))
// dumpLiveRanges fragment (signature start missing from this extraction):
// debug printer emitting, per operation, one column per live range with
// S/E/| marking start/end/interior of liveness, followed by the op name.
 762 FunctionOpInterface function) {
 763 llvm::errs() <<
"SME Tile Liveness: @" << function.getName()
 764 <<
"\nKey:\nS - Start\nE - End\n| - Live\n";
 765 for (
auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
 766 llvm::errs() <<
"^bb" << blockIdx <<
":\n";
 768 unsigned operationIndex = operationToIndexMap.at(&op);
 769 for (LiveRange
const *range : liveRanges) {
 771 for (
auto it = range->ranges->begin(); it != range->ranges->end();
// An index that both ends one interval and starts the next prints '|'.
 773 if (it.start() == operationIndex)
 774 liveness = (liveness ==
'E' ?
'|' :
'S');
 775 else if (it.stop() == operationIndex)
 776 liveness = (liveness ==
'S' ?
'|' :
'E');
 777 else if (operationIndex >= it.start() && operationIndex < it.stop())
 780 llvm::errs() << liveness;
 782 llvm::errs() <<
' ' << op.getName() <<
'\n';
 785 llvm::errs() <<
"==========\n";
// Test pass driving the full allocation pipeline: preprocess, number
// operations, gather + dump initial live ranges, coalesce + dump, allocate
// tiles, then assign ids / resolve conflicts. NOTE(review): interior lines
// (early returns, rewriter/liveness construction, dump guards) are missing
// from this extraction.
788struct TestTileAllocationPass
 789 :
public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
 790 using TestTileAllocationBase::TestTileAllocationBase;
 791 void runOnOperation()
override {
 792 FunctionOpInterface function = getOperation();
// preprocess-only mode: run the CFG preprocessing and stop.
 793 if (preprocessOnly) {
 794 IRRewriter rewriter(function);
 795 return preprocessForTileAllocation(rewriter, function);
 805 if (function.empty()) {
 810 LiveRange::Allocator liveRangeAllocator;
 814 preprocessForTileAllocation(rewriter, function);
 818 auto maybeOperationToIndexMap = generateOperationNumbering(function);
 819 if (failed(maybeOperationToIndexMap))
 821 auto &operationToIndexMap = *maybeOperationToIndexMap;
 822 auto initialLiveRanges = gatherTileLiveRanges(
 823 operationToIndexMap, liveRangeAllocator, liveness, function);
 824 if (initialLiveRanges.empty())
// Build a sorted view of the non-empty initial ranges purely for the
// debug dump below.
 829 auto nonEmpty = llvm::make_filter_range(
 830 llvm::make_second_range(initialLiveRanges),
 831 [&](LiveRange
const &liveRange) {
return !liveRange.empty(); });
 832 auto initialRanges = llvm::map_to_vector(
 833 nonEmpty, [](LiveRange
const &liveRange) {
return &liveRange; });
 834 llvm::sort(initialRanges,
 835 [](LiveRange
const *a, LiveRange
const *
b) {
return *a < *
b; });
 836 llvm::errs() <<
"\n========== Initial Live Ranges:\n";
 837 dumpLiveRanges(operationToIndexMap, initialRanges, function);
 843 auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
 846 llvm::errs() <<
"\n========== Coalesced Live Ranges:\n";
 847 dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
 851 allocateTilesToLiveRanges(coalescedLiveRanges);
 854 if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
 855 coalescedLiveRanges))) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI32IntegerAttr(int32_t value)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Represents an analysis for computing liveness information from a given top-level operation.
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
@ LLVM_MARK_AS_BITMASK_ENUM
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true typeA is >= (in terms of bytes) than typeB.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref