55#include "llvm/ADT/IntervalMap.h"
56#include "llvm/ADT/SmallVectorExtras.h"
57#include "llvm/ADT/TypeSwitch.h"
60#define GEN_PASS_DEF_TESTTILEALLOCATION
61#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
// Bitmask enum describing which ZA tile slices a virtual tile occupies.
// NOTE(review): the enumerators (kNone, kZA0B, ..., kZA15Q) are not visible
// in this fragment; presumably each tile's mask sets the bits of the 128-bit
// sub-tiles it overlaps so that aliasing tiles conflict — confirm upstream.
69enum class TileMask :
unsigned {
// Per-type lookup tables mapping a tile ID to its TileMask; one table per
// element type: 1 x ZA.B, 2 x ZA.H, 4 x ZA.S, 8 x ZA.D, 16 x ZA.Q.
115 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
116 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
117 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
118 TileMask::kZA2S, TileMask::kZA3S};
119 static constexpr std::array ZA_D_MASKS = {
120 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
121 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
122 static constexpr std::array ZA_Q_MASKS = {
123 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
124 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
125 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
126 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
// Switch selecting the mask table for a given ArmSMETileType. The `return`
// line of each case (presumably `return ZA_*_MASKS;`) is elided in this
// fragment; the unreachable guards the exhaustive switch.
128 case ArmSMETileType::ZAB:
130 case ArmSMETileType::ZAH:
132 case ArmSMETileType::ZAS:
134 case ArmSMETileType::ZAD:
136 case ArmSMETileType::ZAQ:
139 llvm_unreachable(
"unknown type in getMasks");
// --- TileAllocator methods (fragment; enclosing class header not visible) ---
// Allocates the first tile of `tileType` whose mask does not intersect
// `tilesInUse`, marks it in use, and returns its index. The success/failure
// `return` lines are elided in this visible fragment.
145 FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
146 auto masks = getMasks(tileType);
147 for (
auto [tileId, tileMask] : llvm::enumerate(masks)) {
148 if ((tilesInUse & tileMask) == TileMask::kNone) {
149 tilesInUse |= tileMask;
// Marks a specific (tileType, tileId) pair as in use; asserts it was free.
157 void acquireTileId(ArmSMETileType tileType,
unsigned tileId) {
158 TileMask tileMask = getMasks(tileType)[tileId];
159 assert((tilesInUse & tileMask) == TileMask::kNone &&
160 "cannot acquire allocated tile!");
161 tilesInUse |= tileMask;
// Releases an allocated tile. XOR clears exactly the bits of `tileMask`;
// the assert guarantees they are all currently set.
165 void releaseTileId(ArmSMETileType tileType,
unsigned tileId) {
166 TileMask tileMask = getMasks(tileType)[tileId];
167 assert((tilesInUse & tileMask) == tileMask &&
168 "cannot release unallocated tile!");
169 tilesInUse ^= tileMask;
// Returns a fresh, monotonically increasing ID for a tile spilled to memory
// (presumably based at kInMemoryTileIdBase — declaration not visible here).
173 unsigned allocateInMemoryTileId() {
177 return nextInMemoryTileId++;
// Bitmask of ZA sub-tiles currently allocated.
181 TileMask tilesInUse = TileMask::kNone;
// Splits blocks ending in a cf.cond_br that forwards SME tile values, so the
// cond_br itself carries no tile operands: the tile-carrying jumps move into
// two fresh single-branch successor blocks. (Several original lines — the
// worklist declaration, the insertJump lambda header, the false-branch
// split — are elided in this fragment.)
202void splitCondBranches(
IRRewriter &rewriter, FunctionOpInterface function) {
// Collect every cond_br with at least one operand of SME tile vector type.
204 function.walk([&](cf::CondBranchOp condBranch) {
205 if (llvm::any_of(condBranch->getOperands(), [&](
Value value) {
206 return isValidSMETileVectorType(value.getType());
208 worklist.push_back(condBranch);
// Helper (header elided): emits a plain cf.br to `dest` carrying `args`.
214 cf::BranchOp::create(rewriter, loc, dest, args);
217 for (
auto condBranch : worklist) {
218 auto loc = condBranch.getLoc();
219 Block *block = condBranch->getBlock();
// Split new empty blocks off the end of the cond_br's block (only the
// true-branch split is visible; newFalseBranch is created on an elided line).
220 auto *newTrueBranch = rewriter.
splitBlock(block, block->
end());
// Re-emit the destination jumps (with their tile operands) from the new
// blocks, then retarget the cond_br at the now operand-less new blocks.
222 insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
223 condBranch.getTrueDestOperands());
224 insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
225 condBranch.getFalseDestOperands());
227 condBranch.getFalseDestOperandsMutable().clear();
228 condBranch.getTrueDestOperandsMutable().clear();
229 condBranch.setSuccessor(newTrueBranch, 0);
230 condBranch.setSuccessor(newFalseBranch, 1);
// Inserts arm_sme.copy_tile ops for tile values passed to successor blocks
// via cf.br, giving each block argument a distinct, copyable definition.
// (Parts of the loop body are elided in this fragment.)
247void insertCopiesAtBranches(
IRRewriter &rewriter,
248 FunctionOpInterface function) {
249 for (
Block &block : function.getBlocks()) {
// Only plain branches are handled here; cond_br tile operands were already
// moved into dedicated blocks by splitCondBranches().
251 if (!isa<cf::BranchOp>(terminator))
257 CopyTileOp::create(rewriter, terminator->
getLoc(), operand.get());
// Preprocessing entry point: normalize branches, then insert tile copies.
// (The function's signature line above original line 273 is elided.)
273 FunctionOpInterface function) {
274 splitCondBranches(rewriter, function);
275 insertCopiesAtBranches(rewriter, function);
// Live range for one coalesced set of SME tile values, stored as a set of
// half-open intervals over the function's operation numbering.
284 using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
285 llvm::IntervalMapHalfOpenInfo<unsigned>>;
286 using Allocator = RangeSet::Allocator;
// Sentinel payload for every live interval (IntervalMap requires a value).
288 static constexpr uint8_t kValidLiveRange = 0xff;
290 LiveRange(Allocator &allocator)
291 : ranges(std::make_unique<RangeSet>(allocator)) {}
// True if any interval of this range intersects `otherRange`.
294 bool overlaps(LiveRange
const &otherRange)
const {
295 return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
// True if `point` (an operation index) lies within this live range.
301 bool overlaps(uint64_t point)
const {
302 return ranges->lookup(point) == kValidLiveRange;
// Merges `otherRange` into this one: union of intervals and member values.
306 void unionWith(LiveRange
const &otherRange) {
307 for (
auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
309 ranges->insert(it.start(), it.stop(), kValidLiveRange);
310 values.set_union(otherRange.values);
// Adds `value` to this range with the half-open interval [start, end).
314 void insert(Value value,
unsigned start,
unsigned end) {
315 values.insert(value);
317 ranges->insert(start, end, kValidLiveRange);
320 bool empty()
const {
return ranges->empty(); }
321 unsigned start()
const {
return ranges->start(); }
322 unsigned end()
const {
return ranges->stop(); }
// Orders ranges by start point (used to sort for the linear-scan allocator).
323 bool operator<(LiveRange
const &other)
const {
324 return start() < other.start();
// Tile type shared by all values in this range (body elided in fragment).
327 ArmSMETileType getTileType()
const {
// Interval set (behind unique_ptr so LiveRange stays cheaply movable) and
// the tile ID assigned by the allocator (unset until allocation; may be an
// in-memory spill ID).
336 std::unique_ptr<RangeSet> ranges;
339 std::optional<unsigned> tileId;
// Assigns a dense index to every operation in the function (blocks walked in
// an order computed on an elided line); these indices form the timeline for
// live ranges. Return type and parts of the body are elided in this fragment.
347generateOperationNumbering(FunctionOpInterface function) {
352 for (
Block *block : blocks) {
// Nested regions are rejected: the numbering (and thus liveness) assumes a
// flat CFG of blocks.
358 "ArmSME tile allocation does not support nested regions");
361 operationToIndexMap.try_emplace(&op,
index++);
364 return operationToIndexMap;
// Computes an initial LiveRange per SME tile value using MLIR's Liveness
// analysis plus the operation numbering. (The function name/signature start
// and several body lines are elided in this fragment.)
370 LiveRange::Allocator &liveRangeAllocator,
371 Liveness &liveness, FunctionOpInterface function) {
372 assert(!operationToIndexMap.empty() &&
"expected operation numbering");
// Creates (or extends) the live range for `value`, spanning from its first
// use/def in a block to its last use there. `liveAtBlockEntry` moves the
// start one index earlier so block live-ins begin before the first op.
378 auto defineOrUpdateValueLiveRange = [&](
Value value,
Operation *firstUseOrDef,
380 bool liveAtBlockEntry =
false) {
384 auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
385 LiveRange &valueLiveRange = it->second;
386 auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
388 unsigned startOpIdx =
389 operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
390 unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
391 valueLiveRange.insert(value, startOpIdx, endOpIdx);
// Per block: record ranges for block arguments, values live-in to the block,
// and results of each operation. (Filtering of non-tile values and the
// livenessInfo lookup happen on elided lines.)
394 for (
Block &block : function.getBlocks()) {
398 defineOrUpdateValueLiveRange(argument, &block.
front(), *livenessInfo,
401 for (
Value liveIn : livenessInfo->
in())
402 defineOrUpdateValueLiveRange(liveIn, &block.
front(), *livenessInfo,
407 defineOrUpdateValueLiveRange(
result, &op, *livenessInfo);
// Invokes `callback` with the value each predecessor block passes for
// `blockArg`. Only cf.br / cf.cond_br predecessors are handled; the
// TypeSwitch setup and the `argNumber`/`block` bindings are on elided lines.
415static void forEachPredecessorTileValue(
BlockArgument blockArg,
421 .Case([&](cf::BranchOp branch) {
422 Value predecessorOperand = branch.getDestOperands()[argNumber];
423 callback(predecessorOperand);
425 .Case([&](cf::CondBranchOp condBranch) {
// A cond_br can reach the block via either (or both) of its successors.
426 if (condBranch.getFalseDest() == block) {
427 Value predecessorOperand =
428 condBranch.getFalseDestOperands()[argNumber];
429 callback(predecessorOperand);
431 if (condBranch.getTrueDest() == block) {
432 Value predecessorOperand =
433 condBranch.getTrueDestOperands()[argNumber];
434 callback(predecessorOperand);
// Coalesces live ranges that should share a tile (op results with their tile
// operands; block arguments with predecessor-passed values) by unioning
// non-overlapping ranges. (The function's signature and the liveRanges map
// declaration are elided above original line 444.)
444 for (
auto &[value, liveRange] : initialLiveRanges) {
445 liveRanges.insert({value, &liveRange});
// Merges b's range into a's when the two do not already overlap, then
// repoints every value of b's old range at the merged range.
451 auto mergeValuesIfNonOverlapping = [&](
Value a,
Value b) {
452 LiveRange *aLiveRange = liveRanges.at(a);
453 LiveRange *bLiveRange = liveRanges.at(
b);
454 if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
455 aLiveRange->unionWith(*bLiveRange);
456 for (
Value value : bLiveRange->values)
457 liveRanges[value] = aLiveRange;
// Rule 1: an SME op result should share a range with its tile operand(s).
462 auto unifyDefinitionsWithOperands = [&](
Value value) {
466 for (
auto operand : armSMEOp->getOperands()) {
468 mergeValuesIfNonOverlapping(value, operand);
// Rule 2: a block argument should share a range with each value the
// predecessors pass for it.
473 auto unifyBlockArgumentsWithPredecessors = [&](
Value value) {
474 auto blockArg = dyn_cast<BlockArgument>(value);
477 forEachPredecessorTileValue(blockArg, [&](
Value predecessorTile) {
478 mergeValuesIfNonOverlapping(blockArg, predecessorTile);
// Applies a rule to every tile value. Block-argument unification runs before
// definition/operand unification — the ordering presumably matters for which
// ranges end up merged (rationale elided in this fragment).
482 auto applyRule = [&](
auto rule) {
483 llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
487 applyRule(unifyBlockArgumentsWithPredecessors);
488 applyRule(unifyDefinitionsWithOperands);
// Deduplicate the merged ranges, drop empties, and return the result sorted
// by start point for the linear scan.
492 for (
auto [_, liveRange] : liveRanges) {
493 if (!liveRange->empty())
494 uniqueLiveRanges.insert(liveRange);
498 auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
499 llvm::sort(coalescedLiveRanges,
500 [](LiveRange *a, LiveRange *
b) {
return *a < *
b; });
501 return std::move(coalescedLiveRanges);
// Picks which live range to spill when no hardware tile is free. Preference:
// a "trivial" spill (a single-value range of a compatible tile type —
// remaining conditions elided here), otherwise whichever overlapping range
// ends latest; the new range itself is a candidate in both cases.
506template <
typename OverlappingRangesIterator>
508chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
509 LiveRange *newRange) {
511 auto isTrivialSpill = [&](LiveRange &allocatedRange) {
513 newRange->getTileType()) &&
514 allocatedRange.values.size() == 1 &&
// The new range itself may be the cheapest thing to spill.
518 if (isTrivialSpill(*newRange))
520 auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
521 if (trivialSpill != overlappingRanges.end())
522 return &*trivialSpill;
// Fallback heuristic: order ranges by tile-type size then end point (the
// comparator's body is elided), and spill the maximum element if it ends
// later than the new range does.
525 auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &
b) {
529 LiveRange &latestEndingLiveRange =
530 *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
531 if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
532 return &latestEndingLiveRange;
// Linear-scan allocation of ZA tiles to the coalesced live ranges, with
// active/inactive sets (inactive = a range that has started but sits in a
// hole of its interval set at the current point). Ranges that cannot get a
// hardware tile are spilled to in-memory tile IDs. Several declarations and
// bookkeeping lines are elided in this fragment.
537void allocateTilesToLiveRanges(
539 TileAllocator tileAllocator;
550 for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
551 auto currentPoint = nextRange->start();
// Expire finished active ranges; demote to inactive those with a hole at the
// current point (their tile is released while they are inactive).
553 activeRanges.remove_if([&](LiveRange *activeRange) {
555 if (activeRange->end() <= currentPoint) {
556 tileAllocator.releaseTileId(activeRange->getTileType(),
557 *activeRange->tileId);
561 if (!activeRange->overlaps(currentPoint)) {
562 tileAllocator.releaseTileId(activeRange->getTileType(),
563 *activeRange->tileId);
564 inactiveRanges.insert(activeRange);
// Expire finished inactive ranges; promote back to active those live again
// at the current point, re-acquiring their previously assigned tile.
570 inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
572 if (inactiveRange->end() <= currentPoint) {
576 if (inactiveRange->overlaps(currentPoint)) {
577 tileAllocator.acquireTileId(inactiveRange->getTileType(),
578 *inactiveRange->tileId);
579 activeRanges.insert(inactiveRange);
// Inactive ranges overlapping the new range must keep their tiles reserved,
// so temporarily re-acquire them before attempting allocation.
589 for (LiveRange *inactiveRange : inactiveRanges) {
590 if (inactiveRange->overlaps(*nextRange)) {
593 tileAllocator.acquireTileId(inactiveRange->getTileType(),
594 *inactiveRange->tileId);
595 overlappingInactiveRanges.push_back(inactiveRange);
// Try for a hardware tile; on failure, pick a spill victim among all
// overlapping ranges (active + overlapping-inactive).
600 auto rangeTileType = nextRange->getTileType();
601 auto tileId = tileAllocator.allocateTileId(rangeTileType);
602 if (succeeded(tileId)) {
603 nextRange->tileId = *tileId;
606 auto allOverlappingRanges = llvm::concat<LiveRange>(
607 llvm::make_pointee_range(activeRanges.getArrayRef()),
608 llvm::make_pointee_range(overlappingInactiveRanges));
610 LiveRange *rangeToSpill =
611 chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
612 if (rangeToSpill != nextRange) {
// Spilling another range frees its hardware tile for the new range; the
// victim is removed from whichever set it was in.
614 tileAllocator.releaseTileId(rangeToSpill->getTileType(),
615 *rangeToSpill->tileId);
617 nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
619 if (!activeRanges.remove(rangeToSpill)) {
620 bool removed = inactiveRanges.remove(rangeToSpill);
621 assert(removed &&
"expected a range to be removed!");
// The spilled range gets a fresh in-memory tile ID.
625 rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
630 activeRanges.insert(nextRange);
// Undo the temporary re-acquisition of overlapping inactive ranges' tiles.
633 for (LiveRange *range : overlappingInactiveRanges) {
635 tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
// Writes `tileIdAttr` onto the SME op defining `value` (if any) and onto any
// SME ops using it. (The function name/signature start and the dyn_casts of
// the defining op and users are on elided lines above/between those shown.)
642 IntegerAttr tileIdAttr) {
644 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
646 if (
auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
649 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
// For each allocated live range: stamp its tile ID onto all its values' ops,
// fold copies made redundant by allocation, verify block arguments agree
// with their predecessors, and resolve remaining operand conflicts by
// cloning trivially cloneable producers. Several lines (tileIdAttr setup,
// early-outs, loop epilogues) are elided in this fragment.
655LogicalResult assignTileIdsAndResolveTrivialConflicts(
656 IRRewriter &rewriter, FunctionOpInterface function,
658 for (LiveRange
const *liveRange : allocatedLiveRanges) {
// True if `value` is allocated to this live range's tile: either its op
// already carries the same tile ID, or it belongs to this range's values.
660 auto isAllocatedToSameTile = [&](
Value value) {
662 tileOp && tileOp.
getTileId() == tileIdAttr)
664 return liveRange->values.contains(value);
// Erase copy_tile ops whose source already lives in the same tile.
668 auto foldRedundantCopies = [&](
Value value) -> LogicalResult {
670 if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
// A block argument and every predecessor-passed value must share a tile;
// a mismatch is surfaced as an error via the elided emit below.
678 auto validateBlockArguments = [&](
Value value) {
679 auto blockArg = dyn_cast<BlockArgument>(value);
684 bool tileMismatch =
false;
685 forEachPredecessorTileValue(blockArg, [&](
Value predecessorTile) {
688 if (!isAllocatedToSameTile(predecessorTile)) {
// NOTE(review): "virtial" in this diagnostic (and the one at line 712) is a
// typo for "virtual"; left untouched here since tests may match exact text.
690 "block argument not allocated to the same SME virtial tile as "
// If an op's tile operand sits in a different tile a move would be needed;
// trivially cloneable producers are instead cloned into the right tile.
699 auto resolveTrivialTileConflicts = [&](
Value value) -> LogicalResult {
702 if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
711 tileOp.emitOpError(
"tile operand allocated to different SME "
712 "virtial tile (move required)");
713 error.attachNote(tileOperand->get().getLoc())
714 <<
"tile operand is: " << tileOperand->get();
719 auto clonedOp = operandTileOp.clone();
721 [&] { clonedOp.setTileId(tileOp.getTileId()); });
722 rewriter.
insert(clonedOp);
// Copies can be bypassed entirely; other ops just retarget the operand at
// the clone's result.
723 if (isa<CopyTileOp>(tileOp)) {
725 clonedOp->getResult(0));
728 tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
// Drive the steps above over every value in the live range.
733 for (
Value value : liveRange->values) {
735 assignTileIdToValue(rewriter, value, tileIdAttr);
738 if (succeeded(foldRedundantCopies(value)))
742 if (
failed(validateBlockArguments(value)))
746 if (
failed(resolveTrivialTileConflicts(value)))
// Debug dump: prints one S/E/| column per live range for each operation
// ('S' = range starts here, 'E' = ends here, '|' = live through). The
// function name and a few lines (per-op loop header, '|' default, increment)
// are elided in this fragment.
756 FunctionOpInterface function) {
757 llvm::errs() <<
"SME Tile Liveness: @" << function.getName()
758 <<
"\nKey:\nS - Start\nE - End\n| - Live\n";
759 for (
auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
760 llvm::errs() <<
"^bb" << blockIdx <<
":\n";
762 unsigned operationIndex = operationToIndexMap.at(&op);
763 for (LiveRange
const *range : liveRanges) {
765 for (
auto it = range->ranges->begin(); it != range->ranges->end();
// A point that is both an end and a start collapses to '|', so adjacent
// intervals render as one continuous bar.
767 if (it.start() == operationIndex)
768 liveness = (liveness ==
'E' ?
'|' :
'S');
769 else if (it.stop() == operationIndex)
770 liveness = (liveness ==
'S' ?
'|' :
'E');
771 else if (operationIndex >= it.start() && operationIndex < it.stop())
774 llvm::errs() << liveness;
776 llvm::errs() <<
' ' << op.getName() <<
'\n';
779 llvm::errs() <<
"==========\n";
// Test pass driver. NOTE(review): this fragment interleaves the pass's
// runOnOperation (original lines 782-789) with what appears to be the
// allocateSMETiles implementation (original lines 799+); many lines,
// including the dumpRanges guards around the errs() output, are elided.
782struct TestTileAllocationPass
784 using TestTileAllocationBase::TestTileAllocationBase;
785 void runOnOperation()
override {
786 FunctionOpInterface function = getOperation();
// With preprocess-only, just normalize branches / insert copies and stop.
787 if (preprocessOnly) {
788 IRRewriter rewriter(function);
789 return preprocessForTileAllocation(rewriter, function);
// Allocation proper: nothing to do for an empty function.
799 if (function.empty()) {
804 LiveRange::Allocator liveRangeAllocator;
808 preprocessForTileAllocation(rewriter, function);
// Number operations and compute initial live ranges; bail if there are none.
812 auto operationToIndexMap = generateOperationNumbering(function);
813 auto initialLiveRanges = gatherTileLiveRanges(
814 operationToIndexMap, liveRangeAllocator, liveness, function);
815 if (initialLiveRanges.empty())
// Debug dump of the sorted, non-empty initial ranges.
820 auto nonEmpty = llvm::make_filter_range(
821 llvm::make_second_range(initialLiveRanges),
822 [&](LiveRange
const &liveRange) {
return !liveRange.empty(); });
823 auto initialRanges = llvm::map_to_vector(
824 nonEmpty, [](LiveRange
const &liveRange) {
return &liveRange; });
825 llvm::sort(initialRanges,
826 [](LiveRange
const *a, LiveRange
const *
b) {
return *a < *
b; });
827 llvm::errs() <<
"\n========== Initial Live Ranges:\n";
828 dumpLiveRanges(operationToIndexMap, initialRanges, function);
// Coalesce, optionally dump, run the allocator, then assign tile IDs and
// resolve trivial conflicts; failure propagates out of the pass.
834 auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
837 llvm::errs() <<
"\n========== Coalesced Live Ranges:\n";
838 dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
842 allocateTilesToLiveRanges(coalescedLiveRanges);
845 if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
846 coalescedLiveRanges))) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI32IntegerAttr(int32_t value)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Represents an analysis for computing liveness information from a given top-level operation.
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
mlir::IntegerAttr getTileId()
Returns the tile ID assigned to this operation.
@ LLVM_MARK_AS_BITMASK_ENUM
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= (in terms of bytes) typeB.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref