20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
25 #define GEN_PASS_DEF_MEM2REG
26 #include "mlir/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "mem2reg"
/// Insertion-ordered map from an operation to a small set of pointers to
/// `OpOperand`s. In this file it is indexed by the owner of a use, and the
/// use itself is inserted into the associated set.
/// NOTE(review): per the name, the stored operands are presumably the uses
/// that block promotion of a memory slot -- confirm against the analyzer.
100 using BlockingUsesMap =
101 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
105 struct MemorySlotPromotionInfo {
113 BlockingUsesMap userToBlockingUses;
119 class MemorySlotPromotionAnalyzer {
123 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
127 std::optional<MemorySlotPromotionInfo> computeInfo();
137 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
164 class MemorySlotPromoter {
166 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
168 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
187 void computeReachingDefInRegion(
Region *region,
Value reachingDef);
190 void removeBlockingUses();
194 Value getOrCreateDefaultValue();
197 PromotableAllocationOpInterface allocator;
208 MemorySlotPromotionInfo info;
214 MemorySlotPromoter::MemorySlotPromoter(
215 MemorySlot slot, PromotableAllocationOpInterface allocator,
217 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
219 : slot(slot), allocator(allocator), rewriter(rewriter),
220 dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
221 statistics(statistics) {
223 auto isResultOrNewBlockArgument = [&]() {
225 return arg.getOwner()->getParentOp() == allocator;
229 assert(isResultOrNewBlockArgument() &&
230 "a slot must be a result of the allocator or an argument of the child "
231 "regions of the allocator");
235 Value MemorySlotPromoter::getOrCreateDefaultValue() {
239 RewriterBase::InsertionGuard guard(rewriter);
241 return defaultValue = allocator.getDefaultValue(slot, rewriter);
244 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
245 BlockingUsesMap &userToBlockingUses) {
255 userToBlockingUses[use.getOwner()];
256 blockingUses.insert(&use);
268 if (!userToBlockingUses.contains(user))
276 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
277 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
280 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
281 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
291 for (
OpOperand *blockingUse : newBlockingUses) {
292 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
295 userToBlockingUses[blockingUse->getOwner()];
296 newUserBlockingUseSet.insert(blockingUse);
304 for (
auto &[toPromote, _] : userToBlockingUses)
305 if (isa<PromotableMemOpInterface>(toPromote) &&
326 if (visited.contains(user->getBlock()))
328 visited.insert(user->getBlock());
331 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
334 if (memOp.loadsFrom(slot)) {
335 liveInWorkList.push_back(user->getBlock());
341 if (memOp.storesTo(slot))
349 while (!liveInWorkList.empty()) {
350 Block *liveInBlock = liveInWorkList.pop_back_val();
352 if (!liveIn.insert(liveInBlock).second)
362 if (!definingBlocks.contains(pred))
363 liveInWorkList.push_back(pred);
370 void MemorySlotPromotionAnalyzer::computeMergePoints(
379 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
380 if (storeOp.storesTo(slot))
381 definingBlocks.insert(user->getBlock());
383 idfCalculator.setDefiningBlocks(definingBlocks);
386 idfCalculator.setLiveInBlocks(liveIn);
389 idfCalculator.calculate(mergePointsVec);
391 mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
394 bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
396 for (
Block *mergePoint : mergePoints)
398 if (!isa<BranchOpInterface>(pred->getTerminator()))
404 std::optional<MemorySlotPromotionInfo>
405 MemorySlotPromotionAnalyzer::computeInfo() {
406 MemorySlotPromotionInfo info;
412 if (
failed(computeBlockingUses(info.userToBlockingUses)))
418 computeMergePoints(info.mergePoints);
423 if (!areMergePointsUsable(info.mergePoints))
429 Value MemorySlotPromoter::computeReachingDefInBlock(
Block *block,
433 blockOps.push_back(&op);
435 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
436 if (info.userToBlockingUses.contains(memOp))
437 reachingDefs.insert({memOp, reachingDef});
439 if (memOp.storesTo(slot)) {
441 Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
442 assert(stored &&
"a memory operation storing to a slot must provide a "
443 "new definition of the slot");
444 reachingDef = stored;
445 replacedValuesMap[memOp] = stored;
453 void MemorySlotPromoter::computeReachingDefInRegion(
Region *region,
455 assert(reachingDef &&
"expected an initial reaching def to be provided");
457 computeReachingDefInBlock(&region->
front(), reachingDef);
462 llvm::DomTreeNodeBase<Block> *block;
470 dfsStack.emplace_back<DfsJob>(
471 {domTree.getNode(&region->
front()), reachingDef});
473 while (!dfsStack.empty()) {
474 DfsJob job = dfsStack.pop_back_val();
475 Block *block = job.block->getBlock();
477 if (info.mergePoints.contains(block)) {
486 argTypes.push_back(arg.getType());
487 argLocs.push_back(arg.getLoc());
493 info.mergePoints.
erase(block);
494 info.mergePoints.insert(newBlock);
504 allocator.handleBlockArgument(slot, blockArgument, rewriter);
505 job.reachingDef = blockArgument;
511 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
512 assert(job.reachingDef);
514 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
515 for (
BlockOperand &blockOperand : terminator->getBlockOperands()) {
516 if (info.mergePoints.contains(blockOperand.get())) {
518 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
519 .append(job.reachingDef);
525 for (
auto *child : job.block->children())
526 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
538 topoBlockIndices[block] = index;
544 size_t lhsBlockIndex = topoBlockIndices.at(lhs->
getBlock());
545 size_t rhsBlockIndex = topoBlockIndices.at(rhs->
getBlock());
546 if (lhsBlockIndex == rhsBlockIndex)
548 return lhsBlockIndex < rhsBlockIndex;
552 void MemorySlotPromoter::removeBlockingUses() {
554 llvm::make_first_range(info.userToBlockingUses));
564 for (
Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
565 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
566 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
570 reachingDef = getOrCreateDefaultValue();
573 if (toPromoteMemOp.removeBlockingUses(
574 slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
575 dataLayout) == DeletionKind::Delete)
576 toErase.push_back(toPromote);
577 if (toPromoteMemOp.storesTo(slot))
578 if (
Value replacedValue = replacedValuesMap[toPromoteMemOp])
579 replacedValuesList.push_back({toPromoteMemOp, replacedValue});
583 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
585 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
586 rewriter) == DeletionKind::Delete)
587 toErase.push_back(toPromote);
588 if (toPromoteBasic.requiresReplacedValues())
589 toVisit.push_back(toPromoteBasic);
591 for (PromotableOpInterface op : toVisit) {
593 op.visitReplacedValues(replacedValuesList, rewriter);
600 "after promotion, the slot pointer should not be used anymore");
603 void MemorySlotPromoter::promoteSlot() {
605 getOrCreateDefaultValue());
608 removeBlockingUses();
612 for (
Block *mergePoint : info.mergePoints) {
614 auto user = cast<BranchOpInterface>(use.getOwner());
616 user.getSuccessorOperands(use.getOperandNumber());
621 user, [&]() { succOperands.
append(getOrCreateDefaultValue()); });
625 LLVM_DEBUG(llvm::dbgs() <<
"[mem2reg] Promoted memory slot: " << slot.
ptr
631 allocator.handlePromotionComplete(slot, defaultValue, rewriter);
638 bool promotedAny =
false;
640 for (PromotableAllocationOpInterface allocator : allocators) {
641 for (
MemorySlot slot : allocator.getPromotableSlots()) {
646 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
647 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
649 MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
650 std::move(*info), statistics)
662 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
663 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
665 void runOnOperation()
override {
670 bool changed =
false;
684 region.
walk([&](PromotableAllocationOpInterface allocator) {
685 allocators.emplace_back(allocator);
688 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
689 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
697 getAnalysisManager().invalidate({});
701 markAllAnalysesPreserved();
static void dominanceSort(SmallVector< Operation * > &ops, Region &region)
Sorts ops according to dominance.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
static SetVector< llvm::BasicBlock * > getTopologicallySortedBlocks(ArrayRef< llvm::BasicBlock * > basicBlocks)
Get a topologically sorted list of blocks for the given basic blocks.
This class represents an argument of a Block.
A block operand represents an operand that holds a reference to a Block, e.g.
Block represents an ordered list of Operations.
unsigned getNumArguments()
void erase()
Unlink this Block from its parent region and delete it.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
BlockArgListType getArguments()
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class models how operands are forwarded to block arguments in control flow.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, RewriterBase &rewriter, const DataLayout &dataLayout, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
This class represents an efficient way to signal success or failure.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.