20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
24 #define GEN_PASS_DEF_MEM2REG
25 #include "mlir/Transforms/Passes.h.inc"
28 #define DEBUG_TYPE "mem2reg"
99 using BlockingUsesMap =
100 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
104 struct MemorySlotPromotionInfo {
112 BlockingUsesMap userToBlockingUses;
118 class MemorySlotPromotionAnalyzer {
122 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
126 std::optional<MemorySlotPromotionInfo> computeInfo();
136 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
165 class MemorySlotPromoter {
167 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
169 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
171 BlockIndexCache &blockIndexCache);
178 std::optional<PromotableAllocationOpInterface> promoteSlot();
191 void computeReachingDefInRegion(
Region *region,
Value reachingDef);
194 void removeBlockingUses();
198 Value getOrCreateDefaultValue();
201 PromotableAllocationOpInterface allocator;
212 MemorySlotPromotionInfo info;
216 BlockIndexCache &blockIndexCache;
221 MemorySlotPromoter::MemorySlotPromoter(
222 MemorySlot slot, PromotableAllocationOpInterface allocator,
225 BlockIndexCache &blockIndexCache)
226 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
227 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
228 blockIndexCache(blockIndexCache) {
230 auto isResultOrNewBlockArgument = [&]() {
232 return arg.getOwner()->getParentOp() == allocator;
236 assert(isResultOrNewBlockArgument() &&
237 "a slot must be a result of the allocator or an argument of the child "
238 "regions of the allocator");
242 Value MemorySlotPromoter::getOrCreateDefaultValue() {
248 return defaultValue = allocator.getDefaultValue(slot, builder);
251 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
252 BlockingUsesMap &userToBlockingUses) {
262 userToBlockingUses[use.getOwner()];
263 blockingUses.insert(&use);
275 if (!userToBlockingUses.contains(user))
283 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
284 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
287 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
288 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
298 for (
OpOperand *blockingUse : newBlockingUses) {
299 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
302 userToBlockingUses[blockingUse->getOwner()];
303 newUserBlockingUseSet.insert(blockingUse);
311 for (
auto &[toPromote, _] : userToBlockingUses)
312 if (isa<PromotableMemOpInterface>(toPromote) &&
333 if (visited.contains(user->getBlock()))
335 visited.insert(user->getBlock());
338 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
341 if (memOp.loadsFrom(slot)) {
342 liveInWorkList.push_back(user->getBlock());
348 if (memOp.storesTo(slot))
356 while (!liveInWorkList.empty()) {
357 Block *liveInBlock = liveInWorkList.pop_back_val();
359 if (!liveIn.insert(liveInBlock).second)
369 if (!definingBlocks.contains(pred))
370 liveInWorkList.push_back(pred);
377 void MemorySlotPromotionAnalyzer::computeMergePoints(
386 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
387 if (storeOp.storesTo(slot))
388 definingBlocks.insert(user->getBlock());
390 idfCalculator.setDefiningBlocks(definingBlocks);
393 idfCalculator.setLiveInBlocks(liveIn);
396 idfCalculator.calculate(mergePointsVec);
398 mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
401 bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
403 for (
Block *mergePoint : mergePoints)
405 if (!isa<BranchOpInterface>(pred->getTerminator()))
411 std::optional<MemorySlotPromotionInfo>
412 MemorySlotPromotionAnalyzer::computeInfo() {
413 MemorySlotPromotionInfo info;
419 if (failed(computeBlockingUses(info.userToBlockingUses)))
425 computeMergePoints(info.mergePoints);
430 if (!areMergePointsUsable(info.mergePoints))
436 Value MemorySlotPromoter::computeReachingDefInBlock(
Block *block,
440 blockOps.push_back(&op);
442 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
443 if (info.userToBlockingUses.contains(memOp))
444 reachingDefs.insert({memOp, reachingDef});
446 if (memOp.storesTo(slot)) {
448 Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
449 assert(stored &&
"a memory operation storing to a slot must provide a "
450 "new definition of the slot");
451 reachingDef = stored;
452 replacedValuesMap[memOp] = stored;
460 void MemorySlotPromoter::computeReachingDefInRegion(
Region *region,
462 assert(reachingDef &&
"expected an initial reaching def to be provided");
464 computeReachingDefInBlock(®ion->
front(), reachingDef);
469 llvm::DomTreeNodeBase<Block> *block;
477 dfsStack.emplace_back<DfsJob>(
478 {domTree.getNode(®ion->
front()), reachingDef});
480 while (!dfsStack.empty()) {
481 DfsJob job = dfsStack.pop_back_val();
482 Block *block = job.block->getBlock();
484 if (info.mergePoints.contains(block)) {
488 allocator.handleBlockArgument(slot, blockArgument, builder);
489 job.reachingDef = blockArgument;
495 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
496 assert(job.reachingDef);
498 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
499 for (
BlockOperand &blockOperand : terminator->getBlockOperands()) {
500 if (info.mergePoints.contains(blockOperand.get())) {
501 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
502 .append(job.reachingDef);
507 for (
auto *child : job.block->children())
508 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
515 auto [it, inserted] = blockIndexCache.try_emplace(region);
522 blockIndices[block] = index;
530 BlockIndexCache &blockIndexCache) {
540 size_t lhsBlockIndex = topoBlockIndices.at(lhs->
getBlock());
541 size_t rhsBlockIndex = topoBlockIndices.at(rhs->
getBlock());
542 if (lhsBlockIndex == rhsBlockIndex)
544 return lhsBlockIndex < rhsBlockIndex;
548 void MemorySlotPromoter::removeBlockingUses() {
550 llvm::make_first_range(info.userToBlockingUses));
561 for (
Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
562 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
563 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
567 reachingDef = getOrCreateDefaultValue();
570 if (toPromoteMemOp.removeBlockingUses(
571 slot, info.userToBlockingUses[toPromote], builder, reachingDef,
572 dataLayout) == DeletionKind::Delete)
573 toErase.push_back(toPromote);
574 if (toPromoteMemOp.storesTo(slot))
575 if (
Value replacedValue = replacedValuesMap[toPromoteMemOp])
576 replacedValuesList.push_back({toPromoteMemOp, replacedValue});
580 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
582 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
583 builder) == DeletionKind::Delete)
584 toErase.push_back(toPromote);
585 if (toPromoteBasic.requiresReplacedValues())
586 toVisit.push_back(toPromoteBasic);
588 for (PromotableOpInterface op : toVisit) {
590 op.visitReplacedValues(replacedValuesList, builder);
597 "after promotion, the slot pointer should not be used anymore");
600 std::optional<PromotableAllocationOpInterface>
601 MemorySlotPromoter::promoteSlot() {
603 getOrCreateDefaultValue());
606 removeBlockingUses();
610 for (
Block *mergePoint : info.mergePoints) {
612 auto user = cast<BranchOpInterface>(use.getOwner());
614 user.getSuccessorOperands(use.getOperandNumber());
618 succOperands.
append(getOrCreateDefaultValue());
622 LLVM_DEBUG(llvm::dbgs() <<
"[mem2reg] Promoted memory slot: " << slot.
ptr
628 return allocator.handlePromotionComplete(slot, defaultValue, builder);
635 bool promotedAny =
false;
640 BlockIndexCache blockIndexCache;
646 newWorkList.reserve(workList.size());
648 bool changesInThisRound =
false;
649 for (PromotableAllocationOpInterface allocator : workList) {
650 bool changedAllocator =
false;
651 for (
MemorySlot slot : allocator.getPromotableSlots()) {
655 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
656 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
658 std::optional<PromotableAllocationOpInterface> newAllocator =
659 MemorySlotPromoter(slot, allocator, builder, dominance,
660 dataLayout, std::move(*info), statistics,
663 changedAllocator =
true;
667 newWorkList.push_back(*newAllocator);
674 if (!changedAllocator)
675 newWorkList.push_back(allocator);
676 changesInThisRound |= changedAllocator;
678 if (!changesInThisRound)
684 workList.swap(newWorkList);
688 return success(promotedAny);
693 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
694 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
696 void runOnOperation()
override {
701 bool changed =
false;
703 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
704 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
705 auto &dominance = getAnalysis<DominanceInfo>();
715 region.
walk([&](PromotableAllocationOpInterface allocator) {
716 allocators.emplace_back(allocator);
721 dominance, statistics)))
725 markAllAnalysesPreserved();
static void dominanceSort(SmallVector< Operation * > &ops, Region &region, BlockIndexCache &blockIndexCache)
Sorts ops according to dominance.
static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region)
Gets or creates a block index mapping for region.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
This class represents an argument of a Block.
A block operand represents an operand that holds a reference to a Block, e.g. for terminator operations.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
This class models how operands are forwarded to block arguments in control flow.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.