21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/DebugLog.h"
23 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
26 #define GEN_PASS_DEF_MEM2REG
27 #include "mlir/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "mem2reg"
101 using BlockingUsesMap =
102 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
106 struct MemorySlotPromotionInfo {
114 BlockingUsesMap userToBlockingUses;
120 class MemorySlotPromotionAnalyzer {
124 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
128 std::optional<MemorySlotPromotionInfo> computeInfo();
138 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
167 class MemorySlotPromoter {
169 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
171 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
173 BlockIndexCache &blockIndexCache);
180 std::optional<PromotableAllocationOpInterface> promoteSlot();
193 void computeReachingDefInRegion(
Region *region,
Value reachingDef);
196 void removeBlockingUses();
200 Value getOrCreateDefaultValue();
203 PromotableAllocationOpInterface allocator;
214 MemorySlotPromotionInfo info;
218 BlockIndexCache &blockIndexCache;
223 MemorySlotPromoter::MemorySlotPromoter(
224 MemorySlot slot, PromotableAllocationOpInterface allocator,
227 BlockIndexCache &blockIndexCache)
228 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
229 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
230 blockIndexCache(blockIndexCache) {
232 auto isResultOrNewBlockArgument = [&]() {
234 return arg.getOwner()->getParentOp() == allocator;
238 assert(isResultOrNewBlockArgument() &&
239 "a slot must be a result of the allocator or an argument of the child "
240 "regions of the allocator");
244 Value MemorySlotPromoter::getOrCreateDefaultValue() {
250 return defaultValue = allocator.getDefaultValue(slot, builder);
253 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
254 BlockingUsesMap &userToBlockingUses) {
265 auto slotPtrRegionOp =
266 dyn_cast<RegionKindInterface>(slotPtrRegion->
getParentOp());
267 if (slotPtrRegionOp &&
276 userToBlockingUses[use.getOwner()];
277 blockingUses.insert(&use);
289 auto it = userToBlockingUses.find(user);
290 if (it == userToBlockingUses.end())
298 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
299 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
302 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
303 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
313 for (
OpOperand *blockingUse : newBlockingUses) {
314 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
317 userToBlockingUses[blockingUse->getOwner()];
318 newUserBlockingUseSet.insert(blockingUse);
326 for (
auto &[toPromote, _] : userToBlockingUses)
327 if (isa<PromotableMemOpInterface>(toPromote) &&
348 if (!visited.insert(user->getBlock()).second)
352 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
355 if (memOp.loadsFrom(slot)) {
356 liveInWorkList.push_back(user->getBlock());
362 if (memOp.storesTo(slot))
370 while (!liveInWorkList.empty()) {
371 Block *liveInBlock = liveInWorkList.pop_back_val();
373 if (!liveIn.insert(liveInBlock).second)
383 if (!definingBlocks.contains(pred))
384 liveInWorkList.push_back(pred);
391 void MemorySlotPromotionAnalyzer::computeMergePoints(
400 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
401 if (storeOp.storesTo(slot))
402 definingBlocks.insert(user->getBlock());
404 idfCalculator.setDefiningBlocks(definingBlocks);
407 idfCalculator.setLiveInBlocks(liveIn);
410 idfCalculator.calculate(mergePointsVec);
412 mergePoints.insert_range(mergePointsVec);
415 bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
417 for (
Block *mergePoint : mergePoints)
419 if (!isa<BranchOpInterface>(pred->getTerminator()))
425 std::optional<MemorySlotPromotionInfo>
426 MemorySlotPromotionAnalyzer::computeInfo() {
427 MemorySlotPromotionInfo info;
433 if (
failed(computeBlockingUses(info.userToBlockingUses)))
439 computeMergePoints(info.mergePoints);
444 if (!areMergePointsUsable(info.mergePoints))
450 Value MemorySlotPromoter::computeReachingDefInBlock(
Block *block,
454 blockOps.push_back(&op);
456 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
457 if (info.userToBlockingUses.contains(memOp))
458 reachingDefs.insert({memOp, reachingDef});
460 if (memOp.storesTo(slot)) {
462 Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
463 assert(stored &&
"a memory operation storing to a slot must provide a "
464 "new definition of the slot");
465 reachingDef = stored;
466 replacedValuesMap[memOp] = stored;
474 void MemorySlotPromoter::computeReachingDefInRegion(
Region *region,
476 assert(reachingDef &&
"expected an initial reaching def to be provided");
478 computeReachingDefInBlock(®ion->
front(), reachingDef);
483 llvm::DomTreeNodeBase<Block> *block;
491 dfsStack.emplace_back<DfsJob>(
492 {domTree.getNode(®ion->
front()), reachingDef});
494 while (!dfsStack.empty()) {
495 DfsJob job = dfsStack.pop_back_val();
496 Block *block = job.block->getBlock();
498 if (info.mergePoints.contains(block)) {
502 allocator.handleBlockArgument(slot, blockArgument, builder);
503 job.reachingDef = blockArgument;
509 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
510 assert(job.reachingDef);
512 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
513 for (
BlockOperand &blockOperand : terminator->getBlockOperands()) {
514 if (info.mergePoints.contains(blockOperand.get())) {
515 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
516 .append(job.reachingDef);
521 for (
auto *child : job.block->children())
522 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
529 auto [it, inserted] = blockIndexCache.try_emplace(region);
536 blockIndices[block] = index;
544 BlockIndexCache &blockIndexCache) {
554 size_t lhsBlockIndex = topoBlockIndices.at(lhs->
getBlock());
555 size_t rhsBlockIndex = topoBlockIndices.at(rhs->
getBlock());
556 if (lhsBlockIndex == rhsBlockIndex)
558 return lhsBlockIndex < rhsBlockIndex;
562 void MemorySlotPromoter::removeBlockingUses() {
564 llvm::make_first_range(info.userToBlockingUses));
575 for (
Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
576 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
577 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
581 reachingDef = getOrCreateDefaultValue();
584 if (toPromoteMemOp.removeBlockingUses(
585 slot, info.userToBlockingUses[toPromote], builder, reachingDef,
586 dataLayout) == DeletionKind::Delete)
587 toErase.push_back(toPromote);
588 if (toPromoteMemOp.storesTo(slot))
589 if (
Value replacedValue = replacedValuesMap[toPromoteMemOp])
590 replacedValuesList.push_back({toPromoteMemOp, replacedValue});
594 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
596 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
597 builder) == DeletionKind::Delete)
598 toErase.push_back(toPromote);
599 if (toPromoteBasic.requiresReplacedValues())
600 toVisit.push_back(toPromoteBasic);
602 for (PromotableOpInterface op : toVisit) {
604 op.visitReplacedValues(replacedValuesList, builder);
611 "after promotion, the slot pointer should not be used anymore");
614 std::optional<PromotableAllocationOpInterface>
615 MemorySlotPromoter::promoteSlot() {
617 getOrCreateDefaultValue());
620 removeBlockingUses();
624 for (
Block *mergePoint : info.mergePoints) {
626 auto user = cast<BranchOpInterface>(use.getOwner());
628 user.getSuccessorOperands(use.getOperandNumber());
632 succOperands.
append(getOrCreateDefaultValue());
636 LDBG() <<
"Promoted memory slot: " << slot.
ptr;
641 return allocator.handlePromotionComplete(slot, defaultValue, builder);
648 bool promotedAny =
false;
653 BlockIndexCache blockIndexCache;
658 newWorkList.reserve(workList.size());
660 bool changesInThisRound =
false;
661 for (PromotableAllocationOpInterface allocator : workList) {
662 bool changedAllocator =
false;
663 for (
MemorySlot slot : allocator.getPromotableSlots()) {
667 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
668 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
670 std::optional<PromotableAllocationOpInterface> newAllocator =
671 MemorySlotPromoter(slot, allocator, builder, dominance,
672 dataLayout, std::move(*info), statistics,
675 changedAllocator =
true;
679 newWorkList.push_back(*newAllocator);
686 if (!changedAllocator)
687 newWorkList.push_back(allocator);
688 changesInThisRound |= changedAllocator;
690 if (!changesInThisRound)
696 workList.swap(newWorkList);
700 return success(promotedAny);
705 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
706 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
708 void runOnOperation()
override {
715 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
716 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
717 auto &dominance = getAnalysis<DominanceInfo>();
727 region.
walk([&](PromotableAllocationOpInterface allocator) {
728 allocators.emplace_back(allocator);
733 dominance, statistics)))
737 markAllAnalysesPreserved();
static void dominanceSort(SmallVector< Operation * > &ops, Region ®ion, BlockIndexCache &blockIndexCache)
Sorts ops according to dominance.
static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region)
Gets or creates a block index mapping for region.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
This class represents an argument of a Block.
A block operand represents an operand that holds a reference to a Block, e.g.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class models how operands are forwarded to block arguments in control flow.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.