#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
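/// Bitmasks for the ZA tiles. A wider-element tile aliases several narrower
/// tiles, so its mask covers every tile it overlaps with.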
enum class TileMask : unsigned {
  // Masks kZA0B, kZA0H-kZA1H, kZA0S-kZA3S, kZA0D-kZA7D, kZA0Q-kZA15Q, and
  // kNone, marked as a bitmask enum via LLVM_MARK_AS_BITMASK_ENUM (elided).
};
static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                          TileMask::kZA2S, TileMask::kZA3S};
static constexpr std::array ZA_D_MASKS = {
    TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
    TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
static constexpr std::array ZA_Q_MASKS = {
    TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
    TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
    TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
    TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
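/// Returns the tile masks for a given SME tile type (ZAB, ZAH, ZAS, ZAD, ZAQ).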
ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  switch (type) {
  case ArmSMETileType::ZAB: return ZA_B_MASKS;
  case ArmSMETileType::ZAH: return ZA_H_MASKS;
  case ArmSMETileType::ZAS: return ZA_S_MASKS;
  case ArmSMETileType::ZAD: return ZA_D_MASKS;
  case ArmSMETileType::ZAQ: return ZA_Q_MASKS;
  }
  llvm_unreachable("unknown type in getMasks");
}
class TileAllocator {
public:
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }
  unsigned allocateInMemoryTileId() {
    return nextInMemoryTileId++;
  }
private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
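/// Splits `cf.cond_br` ops that carry SME tile operands into a `cf.cond_br`
/// with no tile operands, followed by unconditional branches that forward the
/// operands. This gives each edge its own branch, which later allows tile
/// copies to be inserted per edge.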
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        }))
      worklist.push_back(condBranch);
  });

  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    rewriter.create<cf::BranchOp>(loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    auto newTrueBranch = rewriter.splitBlock(block, block->end());
    auto newFalseBranch = rewriter.splitBlock(block, block->end());
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    condBranch.getFalseDestOperandsMutable().clear();
    condBranch.getTrueDestOperandsMutable().clear();
    condBranch.setSuccessor(newTrueBranch, 0);
    condBranch.setSuccessor(newFalseBranch, 1);
  }
}
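/// Inserts a `CopyTileOp` before each SME tile operand of unconditional
/// branches. Copies that turn out to be redundant are folded away after tile
/// allocation.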
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *terminator = block.getTerminator();
    if (!isa<cf::BranchOp>(terminator))
      continue;
    rewriter.setInsertionPoint(terminator);
    for (OpOperand &operand : terminator->getOpOperands()) {
      if (isValidSMETileVectorType(operand.get().getType())) {
        auto copy =
            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
        rewriter.modifyOpInPlace(terminator,
                                 [&] { operand.assign(copy.getResult()); });
      }
    }
  }
}
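/// Prepares `function` for tile allocation: splits conditional branches and
/// inserts tile copies at branches.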
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
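/// A live range for one or more SME tile values. Tracks the intervals (in the
/// operation numbering) where the values are live, the values themselves, and
/// the tile ID once one has been allocated.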
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;

  // Dummy value stored in the interval map (only the keys matter here).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }
  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }

  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }
  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }
  std::unique_ptr<RangeSet> ranges;
  SetVector<Value> values;
  std::optional<unsigned> tileId;
};
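/// Numbers the operations in `function` (in block dominance order). Live
/// ranges are expressed in terms of this numbering.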
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  auto blocks = getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // Reserve a point for the block entry (used for block arguments).
    for (Operation &op : block->getOperations()) {
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
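/// Gathers the live range of every SME tile value in `function`, keyed by
/// value, using MLIR's liveness analysis and the operation numbering.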
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  // Defines or updates the live range of `value`, covering the interval from
  // `firstUseOrDef` to its last use within the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block)
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
  }

  return liveRanges;
}
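/// Invokes `callback` with the value passed to `blockArg` from each
/// predecessor block (via the `cf.br`/`cf.cond_br` successor operands).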
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *block = blockArg.getOwner();
  unsigned argNumber = blockArg.getArgNumber();
  for (Block *pred : block->getPredecessors()) {
    llvm::TypeSwitch<Operation *>(pred->getTerminator())
        .Case<cf::BranchOp>([&](auto branch) {
          Value predecessorOperand = branch.getDestOperands()[argNumber];
          callback(predecessorOperand);
        })
        .Case<cf::CondBranchOp>([&](auto condBranch) {
          if (condBranch.getFalseDest() == block) {
            Value predecessorOperand =
                condBranch.getFalseDestOperands()[argNumber];
            callback(predecessorOperand);
          }
          if (condBranch.getTrueDest() == block) {
            Value predecessorOperand =
                condBranch.getTrueDestOperands()[argNumber];
            callback(predecessorOperand);
          }
        });
  }
}
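/// Coalesces live ranges to avoid tile moves: block arguments are merged with
/// their predecessor values, and op results with their tile operands, whenever
/// the live ranges do not overlap. Returns the coalesced live ranges sorted by
/// start point.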
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  DenseMap<Value, LiveRange *> liveRanges;
  for (auto &[value, liveRange] : initialLiveRanges) {
    liveRanges.insert({value, &liveRange});
  }

  // Merge the live ranges of `a` and `b` (if they do not overlap).
  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
    LiveRange *aLiveRange = liveRanges.at(a);
    LiveRange *bLiveRange = liveRanges.at(b);
    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
      aLiveRange->unionWith(*bLiveRange);
      for (Value value : bLiveRange->values)
        liveRanges[value] = aLiveRange;
    }
  };

  // Rule: merge definitions with their (non-overlapping) tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!armSMEOp)
      return;
    for (auto operand : armSMEOp->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        mergeValuesIfNonOverlapping(value, operand);
    }
  };

  // Rule: merge block arguments with their (non-overlapping) predecessors.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
    });
  };

  auto applyRule = [&](auto rule) {
    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
  };
  applyRule(unifyBlockArgumentsWithPredecessors);
  applyRule(unifyDefinitionsWithOperands);

  // Keep only the unique, non-empty live ranges.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto [_, liveRange] : liveRanges) {
    if (!liveRange->empty())
      uniqueLiveRanges.insert(liveRange);
  }

  // Sort the coalesced live ranges by start point (ready for allocation).
  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
            [](LiveRange *a, LiveRange *b) { return *a < *b; });
  return std::move(coalescedLiveRanges);
}
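/// Chooses a live range to spill, given a new range and the allocated ranges
/// overlapping it. Prefers "trivial" spills (a single, trivially cloneable
/// definition of a large-enough tile); otherwise picks the range ending last.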
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic 1: Spill a trivially cloneable live range (usually free), whose
  // tile is at least as large as the one we need.
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;
  // Heuristic 2: Otherwise, spill the overlapping range that ends latest
  // (preferring smaller tile types).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return std::make_tuple(a.getTileType(), a.end()) <
           std::make_tuple(b.getTileType(), b.end());
  };
  LiveRange &latestEndingLiveRange =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}
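/// Greedily allocates tile IDs to live ranges (in order of their start point).
/// Ranges that cannot be given a physical tile receive in-memory tile IDs,
/// i.e. they will be spilled.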
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // Live ranges that currently hold a tile and overlap the current point.
  SetVector<LiveRange *> activeRanges;
  // Live ranges that hold a tile but have a hole at the current point.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Expire or deactivate active ranges.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Expire or reactivate inactive ranges.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });
    // 3. Temporarily reserve tiles of inactive ranges that overlap the new
    //    range (so they cannot be handed to `nextRange`).
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }
    // 4. Allocate a tile ID (spilling a range if no tile is free).
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Free the spilled range's tile and reuse it for `nextRange`.
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }
    // 5. The new range becomes active if it was given a physical tile.
    if (*nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);
    // 6. Release the temporarily reserved tiles of overlapping inactive ranges
    //    (skipping any that were spilled and now hold an in-memory tile ID).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
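/// Assigns `tileIdAttr` to the defining op of `value` (if it is an ArmSME tile
/// op) and to any ArmSME users that do not produce a tile result themselves.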
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
  for (Operation *user : value.getUsers()) {
    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
      // Ensure ArmSME users without a tile result still receive a tile ID.
      if (!hasTileResult(tileOp))
        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
    }
  }
}
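/// Assigns tile IDs to every value in the allocated live ranges, folding
/// redundant tile copies and resolving trivial tile conflicts (by cloning
/// trivially cloneable operands). Emits errors for conflicts that cannot be
/// resolved here.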
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);

    // Returns true if `value` was allocated the same tile as `liveRange`.
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    // Folds copies whose source is already allocated to the same tile.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    // Checks every predecessor of a tile block argument got the same tile ID.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg)
        return success();
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (!isAllocatedToSameTile(predecessorTile)) {
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtual tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(!tileMismatch);
    };

    // Resolves trivial tile conflicts by cloning trivially cloneable operands.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // No tile operand, or no conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtual tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning the operand op avoids a move (at the cost of recomputation).
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(clonedOp,
                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to every op in the value's live range.
      assignTileIdToValue(rewriter, value, tileIdAttr);
      // 2. Fold redundant tile copies (nothing more to do for those values).
      if (succeeded(foldRedundantCopies(value)))
        continue;
      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();
      // 4. Resolve trivial tile conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}
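/// Prints live ranges alongside operation names for debugging.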
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function->getContext());
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};

LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // Nothing to do for empty functions.
    return success();
  }
  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function->getContext());
  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);
  // 2. Gather live ranges for each ArmSME tile value within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle the initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    std::sort(initialRanges.begin(), initialRanges.end(),
              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }
  // 3. Coalesce live ranges to avoid unnecessary tile moves.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to the coalesced live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign tile IDs back onto the IR (resolving trivial conflicts).
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase tile ops that are now trivially dead.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}