21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/GenericIteratedDominanceFrontier.h"
27#define GEN_PASS_DEF_MEM2REG
28#include "mlir/Transforms/Passes.h.inc"
31#define DEBUG_TYPE "mem2reg"
/// Insertion-ordered map (llvm::MapVector) from a user operation to the set of
/// its operands that have been recorded as blocking uses.
131using BlockingUsesMap =
132 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
/// Groups the per-operation blocking uses by region. Entries are indexed by the
/// parent region of the operation owning the use, then by that operation.
133using RegionBlockingUsesMap =
134 llvm::SmallMapVector<Region *, BlockingUsesMap, 2>;
140struct RegionPromotionInfo {
147struct MemorySlotPromotionInfo {
156 RegionBlockingUsesMap userToBlockingUses;
169class MemorySlotPromotionAnalyzer {
173 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
177 std::optional<MemorySlotPromotionInfo> computeInfo();
191 computeBlockingUses(RegionBlockingUsesMap &userToBlockingUses,
199 void computeMergePoints(
Region *region,
225class MemorySlotPromoter {
227 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
229 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
231 BlockIndexCache &blockIndexCache);
238 std::optional<PromotableAllocationOpInterface> promoteSlot();
265 void promoteInRegion(
Region *region,
Value reachingDef);
270 void removeBlockingUses(
Region *region);
274 void removeUnusedItems();
278 Value getOrCreateDefaultValue();
281 PromotableAllocationOpInterface allocator;
302 llvm::SmallSetVector<Operation *, 8> toErase;
306 MemorySlotPromotionInfo info;
314 BlockIndexCache &blockIndexCache;
319MemorySlotPromoter::MemorySlotPromoter(
320 MemorySlot slot, PromotableAllocationOpInterface allocator,
323 BlockIndexCache &blockIndexCache)
324 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
325 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
326 blockIndexCache(blockIndexCache) {
328 auto isResultOrNewBlockArgument = [&]() {
330 return arg.getOwner()->getParentOp() == allocator;
334 assert(isResultOrNewBlockArgument() &&
335 "a slot must be a result of the allocator or an argument of the child "
336 "regions of the allocator");
340Value MemorySlotPromoter::getOrCreateDefaultValue() {
344 OpBuilder::InsertionGuard guard(builder);
346 return defaultValue = allocator.getDefaultValue(slot, builder);
349LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
350 RegionBlockingUsesMap &userToBlockingUses,
360 auto slotPtrRegionOp =
361 dyn_cast<RegionKindInterface>(slotPtrRegion->
getParentOp());
362 if (slotPtrRegionOp &&
370 SmallPtrSet<OpOperand *, 4> &blockingUses =
371 userToBlockingUses[use.getOwner()->getParentRegion()][use.getOwner()];
372 blockingUses.insert(&use);
376 RegionSet regionsWithDirectUse;
378 RegionSet regionsWithDirectStore;
387 for (Operation *user : forwardSlice) {
389 auto *blockingUsesMapIt = userToBlockingUses.find(user->getParentRegion());
390 if (blockingUsesMapIt == userToBlockingUses.end())
392 BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
393 auto *it = blockingUsesMap.find(user);
394 if (it == blockingUsesMap.end())
398 if (
auto aliaser = dyn_cast<PromotableAliaserInterface>(user))
401 SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
403 SmallVector<OpOperand *> newBlockingUses;
406 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
407 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
410 regionsWithDirectUse.insert(user->getParentRegion());
411 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
418 MemorySlot aliasSlot =
420 if (!promotable.canUsesBeRemoved(aliasSlot, blockingUses, newBlockingUses,
427 if (promotable.storesTo(aliasSlot))
428 regionsWithDirectStore.insert(user->getParentRegion());
430 regionsWithDirectUse.insert(user->getParentRegion());
438 for (OpOperand *blockingUse : newBlockingUses) {
439 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
441 Operation *useOwner = blockingUse->getOwner();
442 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
444 newUserBlockingUseSet.insert(blockingUse);
450 auto visitRegions = [&](SmallVector<Region *> ®ionsToPropagateFrom,
451 bool hasValueStores) {
452 while (!regionsToPropagateFrom.empty()) {
453 Region *region = regionsToPropagateFrom.pop_back_val();
456 regionsToPromote.contains(region))
459 RegionPromotionInfo ®ionInfo = regionsToPromote[region];
460 regionInfo.hasValueStores = hasValueStores;
462 auto promotableParentOp =
463 dyn_cast<PromotableRegionOpInterface>(region->
getParentOp());
464 if (!promotableParentOp)
467 if (!promotableParentOp.isRegionPromotable(slot, region, hasValueStores))
478 SmallVector<Region *> regionsToPropagateFrom(regionsWithDirectStore.begin(),
479 regionsWithDirectStore.end());
480 if (
failed(visitRegions(regionsToPropagateFrom,
true)))
484 regionsToPropagateFrom.clear();
485 regionsToPropagateFrom.append(regionsWithDirectUse.begin(),
486 regionsWithDirectUse.end());
487 if (
failed(visitRegions(regionsToPropagateFrom,
false)))
494void MemorySlotPromotionAnalyzer::computeMergePoints(
501 idfCalculator.setDefiningBlocks(definingBlocks);
504 idfCalculator.calculate(mergePointsVec);
506 mergePoints.insert_range(mergePointsVec);
509bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
510 SmallPtrSetImpl<Block *> &mergePoints) {
511 for (
Block *mergePoint : mergePoints)
512 for (
Block *pred : mergePoint->getPredecessors())
513 if (!isa<BranchOpInterface>(pred->getTerminator()))
519std::optional<MemorySlotPromotionInfo>
520MemorySlotPromotionAnalyzer::computeInfo() {
521 MemorySlotPromotionInfo info;
529 if (
failed(computeBlockingUses(info.userToBlockingUses, info.regionsToPromote,
542 auto collectStoringBlocks = [&](Value ptr,
const MemorySlot &ptrSlot) {
543 for (OpOperand &use : ptr.
getUses()) {
544 Operation *user = use.getOwner();
545 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
546 if (storeOp.storesTo(ptrSlot))
550 collectStoringBlocks(slot.
ptr, slot);
551 for (
auto &[aliasPtr, aliasInfo] : info.aliasMap)
552 collectStoringBlocks(aliasPtr, aliasInfo.slot);
553 for (
auto &[region, regionInfo] : info.regionsToPromote)
554 if (regionInfo.hasValueStores)
561 for (
auto &[region, defBlocks] : definingBlocks)
562 computeMergePoints(region, defBlocks, info.mergePoints);
567 if (!areMergePointsUsable(info.mergePoints))
573Value MemorySlotPromoter::promoteInBlock(
Block *block, Value reachingDef) {
574 SmallVector<Operation *> blockOps;
576 blockOps.push_back(&op);
577 for (Operation *op : blockOps) {
579 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
580 if (info.userToBlockingUses[memOp->getParentRegion()].contains(memOp))
581 reachingDefs.insert({memOp, reachingDef});
583 MemorySlot aliasSlot =
585 if (memOp.storesTo(aliasSlot)) {
591 reachingDef = getOrCreateDefaultValue();
592 Value reachingDefAtStore = reachingDef;
593 if (slot.
ptr != aliasSlot.
ptr) {
598 reachingDef, aliasSlot, slot, info.aliasMap, builder);
599 assert(reachingDefAtStore &&
600 "projectSlotValueToAliasValue contract violation");
603 memOp.getStored(aliasSlot, builder, reachingDefAtStore, dataLayout);
604 assert(stored &&
"a memory operation storing to a slot must provide a "
605 "new definition of the slot");
609 replacedValuesMap[memOp] = stored;
610 if (aliasSlot.
ptr != slot.
ptr) {
612 slot, info.aliasMap, builder);
613 assert(stored &&
"projectAliasValueToSlotValue contract violation");
615 reachingDef = stored;
621 if (
auto promotableRegionOp = dyn_cast<PromotableRegionOpInterface>(op)) {
622 bool needsPromotion =
false;
623 bool hasValueStores =
false;
624 for (Region ®ion : op->getRegions()) {
625 auto regionInfoIt = info.regionsToPromote.find(®ion);
626 if (regionInfoIt == info.regionsToPromote.end())
628 needsPromotion =
true;
629 if (!regionInfoIt->second.hasValueStores)
632 hasValueStores =
true;
636 if (needsPromotion) {
637 llvm::SmallMapVector<Region *, Value, 2> regionsToProcess;
643 reachingDef = getOrCreateDefaultValue();
645 promotableRegionOp.setupPromotion(slot, reachingDef, hasValueStores,
649 for (Region ®ion : op->getRegions())
650 if (info.regionsToPromote.contains(®ion))
652 regionsToProcess.contains(®ion) &&
653 "reaching definition must be provided for a required region");
656 for (
auto &[region, reachingDef] : regionsToProcess) {
658 "region must be part of the operation");
659 if (!info.regionsToPromote.contains(region))
661 promoteInRegion(region, reachingDef);
668 for (Region ®ion : op->getRegions())
672 reachingDef = promotableRegionOp.finalizePromotion(
673 slot, reachingDef, hasValueStores, reachingAtBlockEnd, builder);
679 for (
auto &[region, reachingDef] : regionsToProcess)
680 removeBlockingUses(region);
685 reachingAtBlockEnd[block] = reachingDef;
689void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
691 promoteInBlock(®ion->
front(), reachingDef);
696 llvm::DomTreeNodeBase<Block> *block;
700 SmallVector<DfsJob> dfsStack;
704 dfsStack.emplace_back<DfsJob>(
705 {domTree.getNode(®ion->
front()), reachingDef});
707 while (!dfsStack.empty()) {
708 DfsJob job = dfsStack.pop_back_val();
709 Block *block = job.block->getBlock();
711 if (info.mergePoints.contains(block)) {
712 BlockArgument blockArgument =
714 job.reachingDef = blockArgument;
717 job.reachingDef = promoteInBlock(block, job.reachingDef);
719 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
720 for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
721 if (info.mergePoints.contains(blockOperand.get())) {
722 if (!job.reachingDef)
723 job.reachingDef = getOrCreateDefaultValue();
725 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
726 .append(job.reachingDef);
731 for (
auto *child : job.block->children())
732 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
740 Block *regionEntryBlock) {
741 auto [it,
inserted] = blockIndexCache.try_emplace(regionEntryBlock);
748 for (
auto [
index, block] : llvm::enumerate(topologicalOrder))
749 blockIndices[block] =
index;
759 BlockIndexCache &blockIndexCache) {
772 size_t lhsBlockIndex = topoBlockIndices.at(
lhs->getBlock());
773 size_t rhsBlockIndex = topoBlockIndices.at(
rhs->getBlock());
774 if (lhsBlockIndex == rhsBlockIndex)
775 return lhs->isBeforeInBlock(
rhs);
776 return lhsBlockIndex < rhsBlockIndex;
780void MemorySlotPromoter::removeBlockingUses(Region *region) {
781 auto *blockingUsesMapIt = info.userToBlockingUses.find(region);
782 if (blockingUsesMapIt == info.userToBlockingUses.end())
784 BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
785 if (blockingUsesMap.empty())
791 region = blockingUsesMap.
front().first->getParentRegion();
793 for (
auto &[op, blockingUses] : blockingUsesMap)
794 assert(op->getParentRegion() == region &&
795 "all operations must still be in the same region");
798 llvm::SmallVector<Operation *> usersToRemoveUses(
799 llvm::make_first_range(blockingUsesMap));
805 for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
806 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
807 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
814 reachingDef = getOrCreateDefaultValue();
817 MemorySlot aliasSlot =
819 Value reachingDefAtBlockingUse = reachingDef;
820 if (aliasSlot.
ptr != slot.
ptr) {
824 reachingDef, aliasSlot, slot, info.aliasMap, builder);
825 assert(reachingDefAtBlockingUse &&
826 "projectSlotValueToAliasValue contract violation");
828 if (toPromoteMemOp.removeBlockingUses(
829 aliasSlot, blockingUsesMap[toPromote], builder,
830 reachingDefAtBlockingUse, dataLayout) == DeletionKind::Delete)
831 toErase.insert(toPromote);
832 if (toPromoteMemOp.storesTo(aliasSlot))
833 if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
834 replacedValues.push_back({toPromoteMemOp, replacedValue});
838 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
840 if (toPromoteBasic.removeBlockingUses(blockingUsesMap[toPromote],
841 builder) == DeletionKind::Delete)
842 toErase.insert(toPromote);
843 if (toPromoteBasic.requiresReplacedValues())
844 toVisitReplacedValues.push_back(toPromoteBasic);
848void MemorySlotPromoter::removeUnusedItems() {
854 SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
855 SmallVector<BlockArgument> usedMergePointArgsToProcess;
862 auto isDefinitelyUsed = [&](BlockArgument arg) {
863 for (
auto &use : arg.getUses()) {
864 if (llvm::is_contained(toErase, use.getOwner()))
870 auto branchOp = dyn_cast<BranchOpInterface>(use.getOwner());
874 std::optional<BlockArgument> successorArgument =
875 branchOp.getSuccessorBlockArgument(use.getOperandNumber());
876 if (!successorArgument)
879 if (!info.mergePoints.contains(successorArgument->getOwner()))
885 bool isLastBlockArgument =
886 successorArgument->getArgNumber() ==
887 successorArgument->getOwner()->getNumArguments() - 1;
888 if (!isLastBlockArgument)
894 for (
Block *mergePoint : info.mergePoints) {
896 if (isDefinitelyUsed(arg))
897 usedMergePointArgsToProcess.push_back(arg);
899 mergePointArgsUnused.insert(arg);
904 while (!usedMergePointArgsToProcess.empty()) {
905 BlockArgument arg = usedMergePointArgsToProcess.pop_back_val();
909 "merge point argument must be the last argument of the merge point");
911 for (BlockOperand &use : mergePoint->
getUses()) {
916 auto branch = cast<BranchOpInterface>(use.getOwner());
917 SuccessorOperands succOperands =
918 branch.getSuccessorOperands(use.getOperandNumber());
928 succOperands.
append(getOrCreateDefaultValue());
930 Value populatedValue = succOperands[arg.
getArgNumber()];
931 auto populatedValueAsArg = dyn_cast<BlockArgument>(populatedValue);
932 if (populatedValueAsArg &&
933 mergePointArgsUnused.erase(populatedValueAsArg))
934 usedMergePointArgsToProcess.push_back(populatedValueAsArg);
938 allocator.handleBlockArgument(slot, arg, builder);
943 for (Operation *toEraseOp : toErase)
950 for (BlockArgument arg : mergePointArgsUnused) {
952 for (BlockOperand &use : mergePoint->
getUses()) {
953 auto branch = cast<BranchOpInterface>(use.getOwner());
954 SuccessorOperands succOperands =
955 branch.getSuccessorOperands(use.getOperandNumber());
962 for (BlockArgument arg : mergePointArgsUnused) {
968std::optional<PromotableAllocationOpInterface>
969MemorySlotPromoter::promoteSlot() {
985 for (PromotableOpInterface op : toVisitReplacedValues) {
987 op.visitReplacedValues(replacedValues, builder);
994 "after promotion, the slot pointer should not be used anymore");
996 LDBG() <<
"Promoted memory slot: " << slot.
ptr;
1001 return allocator.handlePromotionComplete(slot, defaultValue, builder);
1008 bool promotedAny =
false;
1013 BlockIndexCache blockIndexCache;
1018 newWorkList.reserve(workList.size());
1020 bool changesInThisRound =
false;
1021 for (PromotableAllocationOpInterface allocator : workList) {
1022 bool changedAllocator =
false;
1023 for (
MemorySlot slot : allocator.getPromotableSlots()) {
1027 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
1028 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
1030 std::optional<PromotableAllocationOpInterface> newAllocator =
1031 MemorySlotPromoter(slot, allocator, builder, dominance,
1032 dataLayout, std::move(*info), statistics,
1035 changedAllocator =
true;
1039 newWorkList.push_back(*newAllocator);
1046 if (!changedAllocator)
1047 newWorkList.push_back(allocator);
1048 changesInThisRound |= changedAllocator;
1050 if (!changesInThisRound)
1056 workList.swap(newWorkList);
1057 newWorkList.clear();
1065struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
1066 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
1068 void runOnOperation()
override {
1073 bool changed =
false;
1075 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1076 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
1077 auto &dominance = getAnalysis<DominanceInfo>();
1079 for (Region ®ion : scopeOp->
getRegions()) {
1085 SmallVector<PromotableAllocationOpInterface> allocators;
1087 region.
walk([&](PromotableAllocationOpInterface allocator) {
1088 allocators.emplace_back(allocator);
1093 dominance, statistics)))
1097 markAllAnalysesPreserved();
...if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...
static void dominanceSort(SmallVector< Operation * > &ops, Region &region, BlockIndexCache &blockIndexCache)
Sorts ops according to dominance.
static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Block *regionEntryBlock)
Gets or creates a block index mapping for the region of which the entry block is regionEntryBlock.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
void invalidate()
Invalidate dominance info.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool referencesAtMostOneAliasOfSlot(Operation *op, const MemorySlot &rootSlot, const PromotableAliasMap &aliasMap)
Returns true if op's operands reach rootSlot through at most one distinct alias pointer (the root its...
Value convertSlotValueToAliasValue(Value slotValue, const MemorySlot &aliasSlot, const MemorySlot &rootSlot, const PromotableAliasMap &aliasMap, OpBuilder &builder)
Walks the alias chain from rootSlot down to aliasSlot.
void populatePromotableAliasMap(PromotableAliaserInterface aliaser, const MemorySlot &rootSlot, PromotableAliasMap &aliasMap)
Populates aliasMap with alias entries produced by aliaser for operands that already alias rootSlot.
llvm::SetVector< T, Vector, Set, N > SetVector
Value convertAliasValueToSlotValue(Value aliasValue, const MemorySlot &aliasSlot, Value rootReachingDef, const MemorySlot &rootSlot, const PromotableAliasMap &aliasMap, OpBuilder &builder)
Walks the alias chain from aliasSlot back up to rootSlot.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::SmallDenseMap< Value, PromotableSlotAliasInfo, 4 > PromotableAliasMap
Maps an alias slot pointer (a result of a PromotableAliaserInterface op) reachable from a root slot t...
std::optional< MemorySlot > getOpAliasSlot(Operation *op, const MemorySlot &rootSlot, const PromotableAliasMap &aliasMap)
Returns a MemorySlot for the operand of op that aliases rootSlot.ptr (either the root itself or a kno...
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.