55 #include "llvm/ADT/IntervalMap.h"
56 #include "llvm/ADT/TypeSwitch.h"
59 #define GEN_PASS_DEF_TESTTILEALLOCATION
60 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
68 enum class TileMask : unsigned {
114 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
115 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
116 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117 TileMask::kZA2S, TileMask::kZA3S};
118 static constexpr std::array ZA_D_MASKS = {
119 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
121 static constexpr std::array ZA_Q_MASKS = {
122 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
127 case ArmSMETileType::ZAB:
129 case ArmSMETileType::ZAH:
131 case ArmSMETileType::ZAS:
133 case ArmSMETileType::ZAD:
135 case ArmSMETileType::ZAQ:
138 llvm_unreachable(
"unknown type in getMasks");
141 class TileAllocator {
144 FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
145 auto masks = getMasks(tileType);
147 if ((tilesInUse & tileMask) == TileMask::kNone) {
148 tilesInUse |= tileMask;
156 void acquireTileId(ArmSMETileType tileType,
unsigned tileId) {
157 TileMask tileMask = getMasks(tileType)[tileId];
158 assert((tilesInUse & tileMask) == TileMask::kNone &&
159 "cannot acquire allocated tile!");
160 tilesInUse |= tileMask;
164 void releaseTileId(ArmSMETileType tileType,
unsigned tileId) {
165 TileMask tileMask = getMasks(tileType)[tileId];
166 assert((tilesInUse & tileMask) == tileMask &&
167 "cannot release unallocated tile!");
168 tilesInUse ^= tileMask;
172 unsigned allocateInMemoryTileId() {
176 return nextInMemoryTileId++;
180 TileMask tilesInUse = TileMask::kNone;
201 void splitCondBranches(
IRRewriter &rewriter, FunctionOpInterface
function) {
203 function.walk([&](cf::CondBranchOp condBranch) {
204 if (llvm::any_of(condBranch->getOperands(), [&](
Value value) {
205 return isValidSMETileVectorType(value.getType());
207 worklist.push_back(condBranch);
213 cf::BranchOp::create(rewriter, loc, dest, args);
216 for (
auto condBranch : worklist) {
217 auto loc = condBranch.getLoc();
218 Block *block = condBranch->getBlock();
219 auto newTrueBranch = rewriter.
splitBlock(block, block->
end());
220 auto newFalseBranch = rewriter.
splitBlock(block, block->
end());
221 insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
222 condBranch.getTrueDestOperands());
223 insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
224 condBranch.getFalseDestOperands());
226 condBranch.getFalseDestOperandsMutable().clear();
227 condBranch.getTrueDestOperandsMutable().clear();
228 condBranch.setSuccessor(newTrueBranch, 0);
229 condBranch.setSuccessor(newFalseBranch, 1);
246 void insertCopiesAtBranches(
IRRewriter &rewriter,
247 FunctionOpInterface
function) {
248 for (
Block &block :
function.getBlocks()) {
250 if (!isa<cf::BranchOp>(terminator))
256 CopyTileOp::create(rewriter, terminator->
getLoc(), operand.get());
271 void preprocessForTileAllocation(
IRRewriter &rewriter,
272 FunctionOpInterface
function) {
273 splitCondBranches(rewriter,
function);
274 insertCopiesAtBranches(rewriter,
function);
283 using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
284 llvm::IntervalMapHalfOpenInfo<unsigned>>;
285 using Allocator = RangeSet::Allocator;
287 static constexpr uint8_t kValidLiveRange = 0xff;
289 LiveRange(Allocator &allocator)
290 : ranges(std::make_unique<RangeSet>(allocator)) {}
293 bool overlaps(LiveRange
const &otherRange)
const {
294 return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
300 bool overlaps(uint64_t point)
const {
301 return ranges->lookup(point) == kValidLiveRange;
305 void unionWith(LiveRange
const &otherRange) {
306 for (
auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
308 ranges->insert(it.start(), it.stop(), kValidLiveRange);
309 values.set_union(otherRange.values);
313 void insert(
Value value,
unsigned start,
unsigned end) {
314 values.insert(value);
316 ranges->insert(start, end, kValidLiveRange);
319 bool empty()
const {
return ranges->empty(); }
320 unsigned start()
const {
return ranges->start(); }
321 unsigned end()
const {
return ranges->stop(); }
322 bool operator<(LiveRange
const &other)
const {
323 return start() < other.start();
326 ArmSMETileType getTileType()
const {
335 std::unique_ptr<RangeSet> ranges;
338 std::optional<unsigned> tileId;
346 generateOperationNumbering(FunctionOpInterface
function) {
351 for (
Block *block : blocks) {
355 op.walk([&](ArmSMETileOpInterface nestedOp) {
356 assert(&op == nestedOp.getOperation() &&
357 "ArmSME tile allocation does not support nested regions");
360 operationToIndexMap.try_emplace(&op, index++);
363 return operationToIndexMap;
369 LiveRange::Allocator &liveRangeAllocator,
370 Liveness &liveness, FunctionOpInterface
function) {
371 assert(!operationToIndexMap.empty() &&
"expected operation numbering");
377 auto defineOrUpdateValueLiveRange = [&](
Value value,
Operation *firstUseOrDef,
379 bool liveAtBlockEntry =
false) {
383 auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
384 LiveRange &valueLiveRange = it->second;
385 auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
387 unsigned startOpIdx =
388 operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
389 unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
390 valueLiveRange.insert(value, startOpIdx, endOpIdx);
393 for (
Block &block :
function.getBlocks()) {
397 defineOrUpdateValueLiveRange(argument, &block.
front(), *livenessInfo,
400 for (
Value liveIn : livenessInfo->
in())
401 defineOrUpdateValueLiveRange(liveIn, &block.
front(), *livenessInfo,
405 for (
Value result : op.getResults())
406 defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
414 static void forEachPredecessorTileValue(
BlockArgument blockArg,
420 .Case<cf::BranchOp>([&](
auto branch) {
421 Value predecessorOperand = branch.getDestOperands()[argNumber];
422 callback(predecessorOperand);
424 .Case<cf::CondBranchOp>([&](
auto condBranch) {
425 if (condBranch.getFalseDest() == block) {
426 Value predecessorOperand =
427 condBranch.getFalseDestOperands()[argNumber];
428 callback(predecessorOperand);
430 if (condBranch.getTrueDest() == block) {
431 Value predecessorOperand =
432 condBranch.getTrueDestOperands()[argNumber];
433 callback(predecessorOperand);
443 for (
auto &[value, liveRange] : initialLiveRanges) {
444 liveRanges.insert({value, &liveRange});
450 auto mergeValuesIfNonOverlapping = [&](
Value a,
Value b) {
451 LiveRange *aLiveRange = liveRanges.at(a);
452 LiveRange *bLiveRange = liveRanges.at(b);
453 if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
454 aLiveRange->unionWith(*bLiveRange);
455 for (
Value value : bLiveRange->values)
456 liveRanges[value] = aLiveRange;
461 auto unifyDefinitionsWithOperands = [&](
Value value) {
462 auto armSMEOp = value.
getDefiningOp<ArmSMETileOpInterface>();
465 for (
auto operand : armSMEOp->getOperands()) {
467 mergeValuesIfNonOverlapping(value, operand);
472 auto unifyBlockArgumentsWithPredecessors = [&](
Value value) {
473 auto blockArg = dyn_cast<BlockArgument>(value);
476 forEachPredecessorTileValue(blockArg, [&](
Value predecessorTile) {
477 mergeValuesIfNonOverlapping(blockArg, predecessorTile);
481 auto applyRule = [&](
auto rule) {
482 llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
486 applyRule(unifyBlockArgumentsWithPredecessors);
487 applyRule(unifyDefinitionsWithOperands);
491 for (
auto [_, liveRange] : liveRanges) {
492 if (!liveRange->empty())
493 uniqueLiveRanges.insert(liveRange);
497 auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
498 llvm::sort(coalescedLiveRanges,
499 [](LiveRange *a, LiveRange *b) {
return *a < *b; });
500 return std::move(coalescedLiveRanges);
505 template <
typename OverlappingRangesIterator>
507 chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
508 LiveRange *newRange) {
510 auto isTrivialSpill = [&](LiveRange &allocatedRange) {
512 newRange->getTileType()) &&
513 allocatedRange.values.size() == 1 &&
515 allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
517 if (isTrivialSpill(*newRange))
519 auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
520 if (trivialSpill != overlappingRanges.end())
521 return &*trivialSpill;
524 auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
528 LiveRange &latestEndingLiveRange =
529 *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
530 if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
531 return &latestEndingLiveRange;
536 void allocateTilesToLiveRanges(
538 TileAllocator tileAllocator;
549 for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
550 auto currentPoint = nextRange->start();
552 activeRanges.remove_if([&](LiveRange *activeRange) {
554 if (activeRange->end() <= currentPoint) {
555 tileAllocator.releaseTileId(activeRange->getTileType(),
556 *activeRange->tileId);
560 if (!activeRange->overlaps(currentPoint)) {
561 tileAllocator.releaseTileId(activeRange->getTileType(),
562 *activeRange->tileId);
563 inactiveRanges.insert(activeRange);
569 inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
571 if (inactiveRange->end() <= currentPoint) {
575 if (inactiveRange->overlaps(currentPoint)) {
576 tileAllocator.acquireTileId(inactiveRange->getTileType(),
577 *inactiveRange->tileId);
578 activeRanges.insert(inactiveRange);
588 for (LiveRange *inactiveRange : inactiveRanges) {
589 if (inactiveRange->overlaps(*nextRange)) {
592 tileAllocator.acquireTileId(inactiveRange->getTileType(),
593 *inactiveRange->tileId);
594 overlappingInactiveRanges.push_back(inactiveRange);
599 auto rangeTileType = nextRange->getTileType();
600 auto tileId = tileAllocator.allocateTileId(rangeTileType);
601 if (succeeded(tileId)) {
602 nextRange->tileId = *tileId;
605 auto allOverlappingRanges = llvm::concat<LiveRange>(
606 llvm::make_pointee_range(activeRanges.getArrayRef()),
607 llvm::make_pointee_range(overlappingInactiveRanges));
609 LiveRange *rangeToSpill =
610 chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
611 if (rangeToSpill != nextRange) {
613 tileAllocator.releaseTileId(rangeToSpill->getTileType(),
614 *rangeToSpill->tileId);
616 nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
618 if (!activeRanges.remove(rangeToSpill)) {
619 bool removed = inactiveRanges.remove(rangeToSpill);
620 assert(removed &&
"expected a range to be removed!");
624 rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
629 activeRanges.insert(nextRange);
632 for (LiveRange *range : overlappingInactiveRanges) {
634 tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
641 IntegerAttr tileIdAttr) {
642 if (
auto tileOp = value.
getDefiningOp<ArmSMETileOpInterface>())
643 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
645 if (
auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
648 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
654 LogicalResult assignTileIdsAndResolveTrivialConflicts(
655 IRRewriter &rewriter, FunctionOpInterface
function,
657 for (LiveRange
const *liveRange : allocatedLiveRanges) {
659 auto isAllocatedToSameTile = [&](
Value value) {
660 if (
auto tileOp = value.
getDefiningOp<ArmSMETileOpInterface>();
661 tileOp && tileOp.getTileId() == tileIdAttr)
663 return liveRange->values.contains(value);
667 auto foldRedundantCopies = [&](
Value value) -> LogicalResult {
669 if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
677 auto validateBlockArguments = [&](
Value value) {
678 auto blockArg = dyn_cast<BlockArgument>(value);
683 bool tileMismatch =
false;
684 forEachPredecessorTileValue(blockArg, [&](
Value predecessorTile) {
687 if (!isAllocatedToSameTile(predecessorTile)) {
688 blockArg.
getOwner()->getParentOp()->emitOpError(
689 "block argument not allocated to the same SME virtial tile as "
694 return success(!tileMismatch);
698 auto resolveTrivialTileConflicts = [&](
Value value) -> LogicalResult {
701 if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
707 tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
710 tileOp.emitOpError(
"tile operand allocated to different SME "
711 "virtial tile (move required)");
712 error.attachNote(tileOperand->get().getLoc())
713 <<
"tile operand is: " << tileOperand->get();
718 auto clonedOp = operandTileOp.clone();
720 [&] { clonedOp.setTileId(tileOp.getTileId()); });
721 rewriter.
insert(clonedOp);
722 if (isa<CopyTileOp>(tileOp)) {
724 clonedOp->getResult(0));
727 tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
732 for (
Value value : liveRange->values) {
734 assignTileIdToValue(rewriter, value, tileIdAttr);
737 if (succeeded(foldRedundantCopies(value)))
741 if (
failed(validateBlockArguments(value)))
745 if (
failed(resolveTrivialTileConflicts(value)))
755 FunctionOpInterface
function) {
756 llvm::errs() <<
"SME Tile Liveness: @" <<
function.getName()
757 <<
"\nKey:\nS - Start\nE - End\n| - Live\n";
759 llvm::errs() <<
"^bb" << blockIdx <<
":\n";
761 unsigned operationIndex = operationToIndexMap.at(&op);
762 for (LiveRange
const *range : liveRanges) {
764 for (
auto it = range->ranges->begin(); it != range->ranges->end();
766 if (it.start() == operationIndex)
767 liveness = (liveness ==
'E' ?
'|' :
'S');
768 else if (it.stop() == operationIndex)
769 liveness = (liveness ==
'S' ?
'|' :
'E');
770 else if (operationIndex >= it.start() && operationIndex < it.stop())
773 llvm::errs() << liveness;
775 llvm::errs() <<
' ' << op.getName() <<
'\n';
778 llvm::errs() <<
"==========\n";
781 struct TestTileAllocationPass
782 :
public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
783 using TestTileAllocationBase::TestTileAllocationBase;
784 void runOnOperation()
override {
785 FunctionOpInterface
function = getOperation();
786 if (preprocessOnly) {
788 return preprocessForTileAllocation(rewriter,
function);
798 if (
function.empty()) {
803 LiveRange::Allocator liveRangeAllocator;
807 preprocessForTileAllocation(rewriter,
function);
811 auto operationToIndexMap = generateOperationNumbering(
function);
812 auto initialLiveRanges = gatherTileLiveRanges(
813 operationToIndexMap, liveRangeAllocator, liveness,
function);
814 if (initialLiveRanges.empty())
819 auto nonEmpty = llvm::make_filter_range(
820 llvm::make_second_range(initialLiveRanges),
821 [&](LiveRange
const &liveRange) {
return !liveRange.empty(); });
822 auto initialRanges = llvm::to_vector(llvm::map_range(
823 nonEmpty, [](LiveRange
const &liveRange) {
return &liveRange; }));
824 llvm::sort(initialRanges,
825 [](LiveRange
const *a, LiveRange
const *b) {
return *a < *b; });
826 llvm::errs() <<
"\n========== Initial Live Ranges:\n";
827 dumpLiveRanges(operationToIndexMap, initialRanges,
function);
833 auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
836 llvm::errs() <<
"\n========== Coalesced Live Ranges:\n";
837 dumpLiveRanges(operationToIndexMap, coalescedLiveRanges,
function);
841 allocateTilesToLiveRanges(coalescedLiveRanges);
844 if (
failed(assignTileIdsAndResolveTrivialConflicts(rewriter,
function,
845 coalescedLiveRanges))) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Represents an analysis for computing liveness information from a given top-level operation.
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
@ LLVM_MARK_AS_BITMASK_ENUM
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is greater than or equal to (in terms of bytes) typeB.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
static constexpr unsigned kInMemoryTileIdBase
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.