21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/GenericIteratedDominanceFrontier.h"
27#define GEN_PASS_DEF_MEM2REG
28#include "mlir/Transforms/Passes.h.inc"
31#define DEBUG_TYPE "mem2reg"
131using BlockingUsesMap =
132 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
133using RegionBlockingUsesMap =
134 llvm::SmallMapVector<Region *, BlockingUsesMap, 2>;
140struct RegionPromotionInfo {
147struct MemorySlotPromotionInfo {
156 RegionBlockingUsesMap userToBlockingUses;
166class MemorySlotPromotionAnalyzer {
170 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
174 std::optional<MemorySlotPromotionInfo> computeInfo();
187 LogicalResult computeBlockingUses(
188 RegionBlockingUsesMap &userToBlockingUses,
195 void computeMergePoints(
Region *region,
221class MemorySlotPromoter {
223 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
225 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
227 BlockIndexCache &blockIndexCache);
234 std::optional<PromotableAllocationOpInterface> promoteSlot();
261 void promoteInRegion(
Region *region,
Value reachingDef);
266 void removeBlockingUses(
Region *region);
270 void removeUnusedItems();
274 Value getOrCreateDefaultValue();
277 PromotableAllocationOpInterface allocator;
298 llvm::SmallSetVector<Operation *, 8> toErase;
302 MemorySlotPromotionInfo info;
310 BlockIndexCache &blockIndexCache;
315MemorySlotPromoter::MemorySlotPromoter(
316 MemorySlot slot, PromotableAllocationOpInterface allocator,
319 BlockIndexCache &blockIndexCache)
320 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
321 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
322 blockIndexCache(blockIndexCache) {
324 auto isResultOrNewBlockArgument = [&]() {
326 return arg.getOwner()->getParentOp() == allocator;
330 assert(isResultOrNewBlockArgument() &&
331 "a slot must be a result of the allocator or an argument of the child "
332 "regions of the allocator");
336Value MemorySlotPromoter::getOrCreateDefaultValue() {
340 OpBuilder::InsertionGuard guard(builder);
342 return defaultValue = allocator.getDefaultValue(slot, builder);
345LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
346 RegionBlockingUsesMap &userToBlockingUses,
355 auto slotPtrRegionOp =
356 dyn_cast<RegionKindInterface>(slotPtrRegion->
getParentOp());
357 if (slotPtrRegionOp &&
365 SmallPtrSet<OpOperand *, 4> &blockingUses =
366 userToBlockingUses[use.getOwner()->getParentRegion()][use.getOwner()];
367 blockingUses.insert(&use);
371 RegionSet regionsWithDirectUse;
373 RegionSet regionsWithDirectStore;
382 for (Operation *user : forwardSlice) {
384 auto *blockingUsesMapIt = userToBlockingUses.find(user->getParentRegion());
385 if (blockingUsesMapIt == userToBlockingUses.end())
387 BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
388 auto *it = blockingUsesMap.find(user);
389 if (it == blockingUsesMap.end())
392 SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
394 SmallVector<OpOperand *> newBlockingUses;
397 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
398 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
401 regionsWithDirectUse.insert(user->getParentRegion());
402 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
403 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
410 if (promotable.storesTo(slot))
411 regionsWithDirectStore.insert(user->getParentRegion());
413 regionsWithDirectUse.insert(user->getParentRegion());
421 for (OpOperand *blockingUse : newBlockingUses) {
422 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
424 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
425 blockingUsesMap[blockingUse->getOwner()];
426 newUserBlockingUseSet.insert(blockingUse);
432 auto visitRegions = [&](SmallVector<Region *> ®ionsToPropagateFrom,
433 bool hasValueStores) {
434 while (!regionsToPropagateFrom.empty()) {
435 Region *region = regionsToPropagateFrom.pop_back_val();
438 regionsToPromote.contains(region))
441 RegionPromotionInfo ®ionInfo = regionsToPromote[region];
442 regionInfo.hasValueStores = hasValueStores;
444 auto promotableParentOp =
445 dyn_cast<PromotableRegionOpInterface>(region->
getParentOp());
446 if (!promotableParentOp)
449 if (!promotableParentOp.isRegionPromotable(slot, region, hasValueStores))
460 SmallVector<Region *> regionsToPropagateFrom(regionsWithDirectStore.begin(),
461 regionsWithDirectStore.end());
462 if (
failed(visitRegions(regionsToPropagateFrom,
true)))
466 regionsToPropagateFrom.clear();
467 regionsToPropagateFrom.append(regionsWithDirectUse.begin(),
468 regionsWithDirectUse.end());
469 if (
failed(visitRegions(regionsToPropagateFrom,
false)))
476void MemorySlotPromotionAnalyzer::computeMergePoints(
483 idfCalculator.setDefiningBlocks(definingBlocks);
486 idfCalculator.calculate(mergePointsVec);
488 mergePoints.insert_range(mergePointsVec);
491bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
492 SmallPtrSetImpl<Block *> &mergePoints) {
493 for (
Block *mergePoint : mergePoints)
494 for (
Block *pred : mergePoint->getPredecessors())
495 if (!isa<BranchOpInterface>(pred->getTerminator()))
501std::optional<MemorySlotPromotionInfo>
502MemorySlotPromotionAnalyzer::computeInfo() {
503 MemorySlotPromotionInfo info;
512 computeBlockingUses(info.userToBlockingUses, info.regionsToPromote)))
520 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
521 if (storeOp.storesTo(slot))
522 definingBlocks[user->getParentRegion()].insert(user->getBlock());
523 for (
auto &[region, regionInfo] : info.regionsToPromote)
524 if (regionInfo.hasValueStores)
531 for (
auto &[region, defBlocks] : definingBlocks)
532 computeMergePoints(region, defBlocks, info.mergePoints);
537 if (!areMergePointsUsable(info.mergePoints))
543Value MemorySlotPromoter::promoteInBlock(
Block *block, Value reachingDef) {
544 SmallVector<Operation *> blockOps;
546 blockOps.push_back(&op);
547 for (Operation *op : blockOps) {
549 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
550 if (info.userToBlockingUses[memOp->getParentRegion()].contains(memOp))
551 reachingDefs.insert({memOp, reachingDef});
553 if (memOp.storesTo(slot)) {
559 reachingDef = getOrCreateDefaultValue();
560 Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
561 assert(stored &&
"a memory operation storing to a slot must provide a "
562 "new definition of the slot");
563 reachingDef = stored;
564 replacedValuesMap[memOp] = stored;
570 if (
auto promotableRegionOp = dyn_cast<PromotableRegionOpInterface>(op)) {
571 bool needsPromotion =
false;
572 bool hasValueStores =
false;
573 for (Region ®ion : op->getRegions()) {
574 auto regionInfoIt = info.regionsToPromote.find(®ion);
575 if (regionInfoIt == info.regionsToPromote.end())
577 needsPromotion =
true;
578 if (!regionInfoIt->second.hasValueStores)
581 hasValueStores =
true;
585 if (needsPromotion) {
586 llvm::SmallMapVector<Region *, Value, 2> regionsToProcess;
592 reachingDef = getOrCreateDefaultValue();
594 promotableRegionOp.setupPromotion(slot, reachingDef, hasValueStores,
598 for (Region ®ion : op->getRegions())
599 if (info.regionsToPromote.contains(®ion))
601 regionsToProcess.contains(®ion) &&
602 "reaching definition must be provided for a required region");
605 for (
auto &[region, reachingDef] : regionsToProcess) {
607 "region must be part of the operation");
608 if (!info.regionsToPromote.contains(region))
610 promoteInRegion(region, reachingDef);
617 for (Region ®ion : op->getRegions())
621 reachingDef = promotableRegionOp.finalizePromotion(
622 slot, reachingDef, hasValueStores, reachingAtBlockEnd, builder);
628 for (
auto &[region, reachingDef] : regionsToProcess)
629 removeBlockingUses(region);
634 reachingAtBlockEnd[block] = reachingDef;
638void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
640 promoteInBlock(®ion->
front(), reachingDef);
645 llvm::DomTreeNodeBase<Block> *block;
649 SmallVector<DfsJob> dfsStack;
653 dfsStack.emplace_back<DfsJob>(
654 {domTree.getNode(®ion->
front()), reachingDef});
656 while (!dfsStack.empty()) {
657 DfsJob job = dfsStack.pop_back_val();
658 Block *block = job.block->getBlock();
660 if (info.mergePoints.contains(block)) {
661 BlockArgument blockArgument =
663 job.reachingDef = blockArgument;
666 job.reachingDef = promoteInBlock(block, job.reachingDef);
668 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
669 for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
670 if (info.mergePoints.contains(blockOperand.get())) {
671 if (!job.reachingDef)
672 job.reachingDef = getOrCreateDefaultValue();
674 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
675 .append(job.reachingDef);
680 for (
auto *child : job.block->children())
681 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
689 Block *regionEntryBlock) {
690 auto [it,
inserted] = blockIndexCache.try_emplace(regionEntryBlock);
697 for (
auto [
index, block] : llvm::enumerate(topologicalOrder))
698 blockIndices[block] =
index;
708 BlockIndexCache &blockIndexCache) {
721 size_t lhsBlockIndex = topoBlockIndices.at(
lhs->getBlock());
722 size_t rhsBlockIndex = topoBlockIndices.at(
rhs->getBlock());
723 if (lhsBlockIndex == rhsBlockIndex)
724 return lhs->isBeforeInBlock(
rhs);
725 return lhsBlockIndex < rhsBlockIndex;
729void MemorySlotPromoter::removeBlockingUses(Region *region) {
730 auto *blockingUsesMapIt = info.userToBlockingUses.find(region);
731 if (blockingUsesMapIt == info.userToBlockingUses.end())
733 BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
734 if (blockingUsesMap.empty())
740 region = blockingUsesMap.
front().first->getParentRegion();
742 for (
auto &[op, blockingUses] : blockingUsesMap)
743 assert(op->getParentRegion() == region &&
744 "all operations must still be in the same region");
747 llvm::SmallVector<Operation *> usersToRemoveUses(
748 llvm::make_first_range(blockingUsesMap));
754 for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
755 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
756 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
763 reachingDef = getOrCreateDefaultValue();
766 if (toPromoteMemOp.removeBlockingUses(slot, blockingUsesMap[toPromote],
767 builder, reachingDef,
768 dataLayout) == DeletionKind::Delete)
769 toErase.insert(toPromote);
770 if (toPromoteMemOp.storesTo(slot))
771 if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
772 replacedValues.push_back({toPromoteMemOp, replacedValue});
776 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
778 if (toPromoteBasic.removeBlockingUses(blockingUsesMap[toPromote],
779 builder) == DeletionKind::Delete)
780 toErase.insert(toPromote);
781 if (toPromoteBasic.requiresReplacedValues())
782 toVisitReplacedValues.push_back(toPromoteBasic);
786void MemorySlotPromoter::removeUnusedItems() {
792 SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
793 SmallVector<BlockArgument> usedMergePointArgsToProcess;
800 auto isDefinitelyUsed = [&](BlockArgument arg) {
801 for (
auto &use : arg.getUses()) {
802 if (llvm::is_contained(toErase, use.getOwner()))
808 auto branchOp = dyn_cast<BranchOpInterface>(use.getOwner());
812 std::optional<BlockArgument> successorArgument =
813 branchOp.getSuccessorBlockArgument(use.getOperandNumber());
814 if (!successorArgument)
817 if (!info.mergePoints.contains(successorArgument->getOwner()))
823 bool isLastBlockArgument =
824 successorArgument->getArgNumber() ==
825 successorArgument->getOwner()->getNumArguments() - 1;
826 if (!isLastBlockArgument)
832 for (
Block *mergePoint : info.mergePoints) {
834 if (isDefinitelyUsed(arg))
835 usedMergePointArgsToProcess.push_back(arg);
837 mergePointArgsUnused.insert(arg);
842 while (!usedMergePointArgsToProcess.empty()) {
843 BlockArgument arg = usedMergePointArgsToProcess.pop_back_val();
847 "merge point argument must be the last argument of the merge point");
849 for (BlockOperand &use : mergePoint->
getUses()) {
854 auto branch = cast<BranchOpInterface>(use.getOwner());
855 SuccessorOperands succOperands =
856 branch.getSuccessorOperands(use.getOperandNumber());
866 succOperands.
append(getOrCreateDefaultValue());
868 Value populatedValue = succOperands[arg.
getArgNumber()];
869 auto populatedValueAsArg = dyn_cast<BlockArgument>(populatedValue);
870 if (populatedValueAsArg &&
871 mergePointArgsUnused.erase(populatedValueAsArg))
872 usedMergePointArgsToProcess.push_back(populatedValueAsArg);
876 allocator.handleBlockArgument(slot, arg, builder);
881 for (Operation *toEraseOp : toErase)
888 for (BlockArgument arg : mergePointArgsUnused) {
890 for (BlockOperand &use : mergePoint->
getUses()) {
891 auto branch = cast<BranchOpInterface>(use.getOwner());
892 SuccessorOperands succOperands =
893 branch.getSuccessorOperands(use.getOperandNumber());
900 for (BlockArgument arg : mergePointArgsUnused) {
906std::optional<PromotableAllocationOpInterface>
907MemorySlotPromoter::promoteSlot() {
923 for (PromotableOpInterface op : toVisitReplacedValues) {
925 op.visitReplacedValues(replacedValues, builder);
932 "after promotion, the slot pointer should not be used anymore");
934 LDBG() <<
"Promoted memory slot: " << slot.
ptr;
939 return allocator.handlePromotionComplete(slot, defaultValue, builder);
946 bool promotedAny =
false;
951 BlockIndexCache blockIndexCache;
956 newWorkList.reserve(workList.size());
958 bool changesInThisRound =
false;
959 for (PromotableAllocationOpInterface allocator : workList) {
960 bool changedAllocator =
false;
961 for (
MemorySlot slot : allocator.getPromotableSlots()) {
965 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
966 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
968 std::optional<PromotableAllocationOpInterface> newAllocator =
969 MemorySlotPromoter(slot, allocator, builder, dominance,
970 dataLayout, std::move(*info), statistics,
973 changedAllocator =
true;
977 newWorkList.push_back(*newAllocator);
984 if (!changedAllocator)
985 newWorkList.push_back(allocator);
986 changesInThisRound |= changedAllocator;
988 if (!changesInThisRound)
994 workList.swap(newWorkList);
1003struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
1004 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
1006 void runOnOperation()
override {
1011 bool changed =
false;
1013 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1014 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
1015 auto &dominance = getAnalysis<DominanceInfo>();
1017 for (Region ®ion : scopeOp->
getRegions()) {
1023 SmallVector<PromotableAllocationOpInterface> allocators;
1025 region.
walk([&](PromotableAllocationOpInterface allocator) {
1026 allocators.emplace_back(allocator);
1031 dominance, statistics)))
1035 markAllAnalysesPreserved();
… if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method …
static void dominanceSort(SmallVector< Operation * > &ops, Region ®ion, BlockIndexCache &blockIndexCache)
Sorts ops according to dominance.
static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Block *regionEntryBlock)
Gets or creates a block index mapping for the region of which the entry block is regionEntryBlock.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the number of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
void invalidate()
Invalidate dominance info.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total number of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total number of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.