20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
26 #define GEN_PASS_DEF_MEM2REG
27 #include "mlir/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "mem2reg"
101 using BlockingUsesMap =
102 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
106 struct MemorySlotPromotionInfo {
114 BlockingUsesMap userToBlockingUses;
120 class MemorySlotPromotionAnalyzer {
123 : slot(slot), dominance(dominance) {}
127 std::optional<MemorySlotPromotionInfo> computeInfo();
137 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
163 class MemorySlotPromoter {
165 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
167 MemorySlotPromotionInfo info,
186 void computeReachingDefInRegion(
Region *region,
Value reachingDef);
189 void removeBlockingUses();
193 Value getLazyDefaultValue();
196 PromotableAllocationOpInterface allocator;
205 MemorySlotPromotionInfo info;
211 MemorySlotPromoter::MemorySlotPromoter(
212 MemorySlot slot, PromotableAllocationOpInterface allocator,
215 : slot(slot), allocator(allocator), rewriter(rewriter),
216 dominance(dominance), info(std::move(info)), statistics(statistics) {
218 auto isResultOrNewBlockArgument = [&]() {
220 return arg.getOwner()->getParentOp() == allocator;
224 assert(isResultOrNewBlockArgument() &&
225 "a slot must be a result of the allocator or an argument of the child "
226 "regions of the allocator");
230 Value MemorySlotPromoter::getLazyDefaultValue() {
234 RewriterBase::InsertionGuard guard(rewriter);
236 return defaultValue = allocator.getDefaultValue(slot, rewriter);
239 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
240 BlockingUsesMap &userToBlockingUses) {
250 userToBlockingUses[use.getOwner()];
251 blockingUses.insert(&use);
263 if (!userToBlockingUses.contains(user))
271 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
272 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
274 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
275 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses))
284 for (
OpOperand *blockingUse : newBlockingUses) {
285 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
288 userToBlockingUses[blockingUse->getOwner()];
289 newUserBlockingUseSet.insert(blockingUse);
297 for (
auto &[toPromote, _] : userToBlockingUses)
298 if (isa<PromotableMemOpInterface>(toPromote) &&
319 if (visited.contains(user->getBlock()))
321 visited.insert(user->getBlock());
324 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
327 if (memOp.loadsFrom(slot)) {
328 liveInWorkList.push_back(user->getBlock());
334 if (memOp.storesTo(slot))
342 while (!liveInWorkList.empty()) {
343 Block *liveInBlock = liveInWorkList.pop_back_val();
345 if (!liveIn.insert(liveInBlock).second)
355 if (!definingBlocks.contains(pred))
356 liveInWorkList.push_back(pred);
363 void MemorySlotPromotionAnalyzer::computeMergePoints(
372 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
373 if (storeOp.storesTo(slot))
374 definingBlocks.insert(user->getBlock());
376 idfCalculator.setDefiningBlocks(definingBlocks);
379 idfCalculator.setLiveInBlocks(liveIn);
382 idfCalculator.calculate(mergePointsVec);
384 mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
387 bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
389 for (
Block *mergePoint : mergePoints)
391 if (!isa<BranchOpInterface>(pred->getTerminator()))
397 std::optional<MemorySlotPromotionInfo>
398 MemorySlotPromotionAnalyzer::computeInfo() {
399 MemorySlotPromotionInfo info;
405 if (
failed(computeBlockingUses(info.userToBlockingUses)))
411 computeMergePoints(info.mergePoints);
416 if (!areMergePointsUsable(info.mergePoints))
422 Value MemorySlotPromoter::computeReachingDefInBlock(
Block *block,
426 blockOps.push_back(&op);
428 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
429 if (info.userToBlockingUses.contains(memOp))
430 reachingDefs.insert({memOp, reachingDef});
432 if (memOp.storesTo(slot)) {
434 Value stored = memOp.getStored(slot, rewriter);
435 assert(stored &&
"a memory operation storing to a slot must provide a "
436 "new definition of the slot");
437 reachingDef = stored;
445 void MemorySlotPromoter::computeReachingDefInRegion(
Region *region,
448 computeReachingDefInBlock(&region->
front(), reachingDef);
453 llvm::DomTreeNodeBase<Block> *block;
461 dfsStack.emplace_back<DfsJob>(
462 {domTree.getNode(&region->
front()), reachingDef});
464 while (!dfsStack.empty()) {
465 DfsJob job = dfsStack.pop_back_val();
466 Block *block = job.block->getBlock();
468 if (info.mergePoints.contains(block)) {
477 argTypes.push_back(arg.getType());
478 argLocs.push_back(arg.getLoc());
484 info.mergePoints.
erase(block);
485 info.mergePoints.insert(newBlock);
495 allocator.handleBlockArgument(slot, blockArgument, rewriter);
496 job.reachingDef = blockArgument;
502 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
504 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
505 for (
BlockOperand &blockOperand : terminator->getBlockOperands()) {
506 if (info.mergePoints.contains(blockOperand.get())) {
507 if (!job.reachingDef)
508 job.reachingDef = getLazyDefaultValue();
510 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
511 .append(job.reachingDef);
517 for (
auto *child : job.block->children())
518 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
530 topoBlockIndices[block] = index;
536 size_t lhsBlockIndex = topoBlockIndices.at(lhs->
getBlock());
537 size_t rhsBlockIndex = topoBlockIndices.at(rhs->
getBlock());
538 if (lhsBlockIndex == rhsBlockIndex)
540 return lhsBlockIndex < rhsBlockIndex;
544 void MemorySlotPromoter::removeBlockingUses() {
546 llvm::make_first_range(info.userToBlockingUses));
552 for (
Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
553 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
554 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
558 reachingDef = getLazyDefaultValue();
561 if (toPromoteMemOp.removeBlockingUses(
562 slot, info.userToBlockingUses[toPromote], rewriter,
563 reachingDef) == DeletionKind::Delete)
564 toErase.push_back(toPromote);
569 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
571 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
572 rewriter) == DeletionKind::Delete)
573 toErase.push_back(toPromote);
580 "after promotion, the slot pointer should not be used anymore");
583 void MemorySlotPromoter::promoteSlot() {
587 removeBlockingUses();
591 for (
Block *mergePoint : info.mergePoints) {
593 auto user = cast<BranchOpInterface>(use.getOwner());
595 user.getSuccessorOperands(use.getOperandNumber());
600 user, [&]() { succOperands.
append(getLazyDefaultValue()); });
604 LLVM_DEBUG(llvm::dbgs() <<
"[mem2reg] Promoted memory slot: " << slot.
ptr
610 allocator.handlePromotionComplete(slot, defaultValue, rewriter);
616 bool promotedAny =
false;
618 for (PromotableAllocationOpInterface allocator : allocators) {
619 for (
MemorySlot slot : allocator.getPromotableSlots()) {
624 MemorySlotPromotionAnalyzer analyzer(slot, dominance);
625 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
627 MemorySlotPromoter(slot, allocator, rewriter, dominance,
628 std::move(*info), statistics)
639 Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
647 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
648 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
650 void runOnOperation()
override {
static MLIRContext * getContext(OpFoldResult val)
static void dominanceSort(SmallVector< Operation * > &ops, Region &region)
Sorts ops according to dominance.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
static SetVector< llvm::BasicBlock * > getTopologicallySortedBlocks(llvm::Function *func)
Get a topologically sorted list of blocks for the given function.
This class represents an argument of a Block.
A block operand represents an operand that holds a reference to a Block, e.g.
Block represents an ordered list of Operations.
unsigned getNumArguments()
void erase()
Unlink this Block from its parent region and delete it.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
BlockArgListType getArguments()
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
bool enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Pattern applying mem2reg to the regions of the operations on which it matches.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class models how operands are forwarded to block arguments in control flow.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, RewriterBase &rewriter, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
This class represents an efficient way to signal success or failure.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.