21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
25 #define GEN_PASS_DEF_MEM2REG
26 #include "mlir/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "mem2reg"
// For each operation that uses the slot (directly, or through one of its
// results once that result's uses are reported as blocking), the set of its
// operands whose uses must be removed for the slot to be promoted.
// llvm::MapVector iterates in insertion order, keeping later processing
// deterministic.
100 using BlockingUsesMap =
101 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
// Result of MemorySlotPromotionAnalyzer::computeInfo for one slot, consumed by
// MemorySlotPromoter. Besides the blocking uses below, the code also reads a
// set of `mergePoints` from it.
105 struct MemorySlotPromotionInfo {
// Users of the slot mapped to the operand uses they must have removed.
113 BlockingUsesMap userToBlockingUses;
// Decides whether a single memory slot can be promoted and, if so, computes the
// MemorySlotPromotionInfo needed by MemorySlotPromoter.
119 class MemorySlotPromotionAnalyzer {
// Keeps the slot together with the dominance and data layout information the
// analysis queries.
123 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
// Runs the analysis. The optional is presumably empty when the slot cannot be
// promoted (blocking uses not removable, or unusable merge points) — confirm
// against the early exits in computeInfo.
127 std::optional<MemorySlotPromotionInfo> computeInfo();
// Records, per user, the uses that must be removed before promotion;
// presumably fails as soon as a user reports that its uses cannot be removed.
137 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
// Rewrites the uses of one memory slot into SSA values, driven by the
// MemorySlotPromotionInfo computed by MemorySlotPromotionAnalyzer.
166 class MemorySlotPromoter {
// `info` is taken by value; `blockIndexCache` is held by reference and must
// outlive the promoter.
168 MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
170 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
172 BlockIndexCache &blockIndexCache);
// Performs the promotion and returns the result of
// `allocator.handlePromotionComplete`; a returned allocator is queued again by
// the caller.
179 std::optional<PromotableAllocationOpInterface> promoteSlot();
// Propagates the reaching definition of the slot through `region`, starting
// from `reachingDef` at the entry block.
192 void computeReachingDefInRegion(
Region *region,
Value reachingDef);
// Removes the blocking uses recorded in `info`, collecting users that report
// DeletionKind::Delete for erasure.
195 void removeBlockingUses();
// Returns the slot's default value, creating it through the allocator.
199 Value getOrCreateDefaultValue();
202 PromotableAllocationOpInterface allocator;
213 MemorySlotPromotionInfo info;
// Per-region block indices shared across promotions (owned by the caller).
217 BlockIndexCache &blockIndexCache;
// Binds the promoter to one slot; the analysis result `info` is moved into the
// member of the same name.
222 MemorySlotPromoter::MemorySlotPromoter(
223 MemorySlot slot, PromotableAllocationOpInterface allocator,
226 BlockIndexCache &blockIndexCache)
227 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
228 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
229 blockIndexCache(blockIndexCache) {
// Debug-only sanity check: the slot pointer must be produced by the allocator
// itself or be a block argument of a region owned by the allocator.
231 auto isResultOrNewBlockArgument = [&]() {
233 return arg.getOwner()->getParentOp() == allocator;
237 assert(isResultOrNewBlockArgument() &&
238 "a slot must be a result of the allocator or an argument of the child "
239 "regions of the allocator");
// Returns the default value of the slot, obtained from
// `allocator.getDefaultValue` and stored in `defaultValue`.
// NOTE(review): callers invoke this repeatedly, so an already-set
// `defaultValue` is presumably returned early before this line — confirm.
243 Value MemorySlotPromoter::getOrCreateDefaultValue() {
249 return defaultValue = allocator.getDefaultValue(slot, builder);
// Collects, for every operation that uses the slot pointer directly or through
// results reported as blocking, the operand uses that must be removed before
// the slot can be promoted. Users are queried through PromotableOpInterface or
// PromotableMemOpInterface.
252 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
253 BlockingUsesMap &userToBlockingUses) {
// Query the kind of the region holding the slot pointer; presumably promotion
// is refused for some region kinds (e.g. graph regions) — confirm.
264 auto slotPtrRegionOp =
265 dyn_cast<RegionKindInterface>(slotPtrRegion->
getParentOp());
266 if (slotPtrRegionOp &&
// Each initial use (presumably of the slot pointer) is recorded as a blocking
// use of the operation that owns it.
275 userToBlockingUses[use.getOwner()];
276 blockingUses.insert(&use);
// Users with no recorded blocking uses are presumably skipped.
288 auto it = userToBlockingUses.find(user);
289 if (it == userToBlockingUses.end())
// Ask the user whether its blocking uses can be removed. The user may in turn
// report uses of its own results (`newBlockingUses`) that must also go.
297 if (
auto promotable = dyn_cast<PromotableOpInterface>(user)) {
298 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
301 }
else if (
auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
302 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
// Newly reported blocking uses must be uses of this user's results; record
// them against the operations that own them.
312 for (
OpOperand *blockingUse : newBlockingUses) {
313 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
316 userToBlockingUses[blockingUse->getOwner()];
317 newUserBlockingUseSet.insert(blockingUse);
// Final pass over the collected users that are promotable memory ops.
325 for (
auto &[toPromote, _] : userToBlockingUses)
326 if (isa<PromotableMemOpInterface>(toPromote) &&
// Live-in computation: each block containing a user of the slot is scanned
// once. If a memory op in it loads from the slot, the slot value is live-in
// there and the block seeds the worklist. A store to the slot is also checked,
// presumably ending the scan of that block.
347 if (!visited.insert(user->getBlock()).second)
351 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
354 if (memOp.loadsFrom(slot)) {
355 liveInWorkList.push_back(user->getBlock());
361 if (memOp.storesTo(slot))
// Propagate live-in backwards: a live-in block makes its predecessors (`pred`)
// live-in, unless a predecessor is in `definingBlocks` (it stores to the slot).
369 while (!liveInWorkList.empty()) {
370 Block *liveInBlock = liveInWorkList.pop_back_val();
372 if (!liveIn.insert(liveInBlock).second)
382 if (!definingBlocks.contains(pred))
383 liveInWorkList.push_back(pred);
// Computes the merge points of the slot: blocks where different definitions
// may meet, so a new block argument is required. They are the iterated
// dominance frontier of the blocks that store to the slot, pruned to blocks in
// which the slot is live-in.
390 void MemorySlotPromotionAnalyzer::computeMergePoints(
// A block defines the slot if it contains an operation storing to it.
399 if (
auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
400 if (storeOp.storesTo(slot))
401 definingBlocks.insert(user->getBlock());
403 idfCalculator.setDefiningBlocks(definingBlocks);
406 idfCalculator.setLiveInBlocks(liveIn);
409 idfCalculator.calculate(mergePointsVec);
411 mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
// Returns whether every merge point can be given a new block argument. This
// requires every predecessor terminator to implement BranchOpInterface,
// because the promoter appends the forwarded value to its successor operands.
414 bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
416 for (
Block *mergePoint : mergePoints)
418 if (!isa<BranchOpInterface>(pred->getTerminator()))
// Runs the analysis for the slot in three steps: blocking uses, merge points,
// and merge-point usability. The blocking-use and usability checks guard early
// exits, presumably returning an empty optional.
424 std::optional<MemorySlotPromotionInfo>
425 MemorySlotPromotionAnalyzer::computeInfo() {
426 MemorySlotPromotionInfo info;
432 if (failed(computeBlockingUses(info.userToBlockingUses)))
438 computeMergePoints(info.mergePoints);
443 if (!areMergePointsUsable(info.mergePoints))
// Walks `block` in order and returns the reaching definition of the slot at
// its end. Every memory op that needs promotion records the definition
// reaching it, and every store to the slot replaces the current definition
// with the value it stores.
449 Value MemorySlotPromoter::computeReachingDefInBlock(
Block *block,
// The block's operations are collected into `blockOps` before being visited.
453 blockOps.push_back(&op);
455 if (
auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
// Remember the definition visible to this op before it executes.
456 if (info.userToBlockingUses.contains(memOp))
457 reachingDefs.insert({memOp, reachingDef});
// A store yields the new definition; keep it so that users needing replaced
// values can be notified later.
459 if (memOp.storesTo(slot)) {
461 Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
462 assert(stored &&
"a memory operation storing to a slot must provide a "
463 "new definition of the slot");
464 reachingDef = stored;
465 replacedValuesMap[memOp] = stored;
// Propagates reaching definitions of the slot through `region`, starting from
// `reachingDef` at the entry block. Blocks are visited by a depth-first walk
// over the region's dominator tree: each block starts from the definition
// leaving its immediate dominator, or from a fresh block argument if it is a
// merge point.
473 void MemorySlotPromoter::computeReachingDefInRegion(
Region *region,
475 assert(reachingDef &&
"expected an initial reaching def to be provided");
// Entry block processed directly (presumably the single-block fast path).
477 computeReachingDefInBlock(&region->
front(), reachingDef);
// Work item: a dominator-tree node and the definition reaching it.
482 llvm::DomTreeNodeBase<Block> *block;
490 dfsStack.emplace_back<DfsJob>(
491 {domTree.getNode(&region->
front()), reachingDef});
493 while (!dfsStack.empty()) {
494 DfsJob job = dfsStack.pop_back_val();
495 Block *block = job.block->getBlock();
// At a merge point the slot's value arrives through a new block argument.
497 if (info.mergePoints.contains(block)) {
501 allocator.handleBlockArgument(slot, blockArgument, builder);
502 job.reachingDef = blockArgument;
508 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
509 assert(job.reachingDef);
// Forward the outgoing definition to successors that are merge points by
// appending it to the corresponding successor operands.
511 if (
auto terminator = dyn_cast<BranchOpInterface>(block->
getTerminator())) {
512 for (
BlockOperand &blockOperand : terminator->getBlockOperands()) {
513 if (info.mergePoints.contains(blockOperand.get())) {
514 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
515 .append(job.reachingDef);
// Dominator-tree children inherit this block's outgoing definition.
520 for (
auto *child : job.block->children())
521 dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
// Looks up `region` in the cache, inserting an empty entry if absent. Block
// indices are then filled in, presumably only when `inserted` is true.
528 auto [it, inserted] = blockIndexCache.try_emplace(region);
535 blockIndices[block] = index;
// Comparator: operations in different blocks are ordered by the index of
// their block in `topoBlockIndices`. Operations in the same block take the
// equal-index branch, presumably an in-block order check such as
// `isBeforeInBlock` — confirm.
543 BlockIndexCache &blockIndexCache) {
553 size_t lhsBlockIndex = topoBlockIndices.at(lhs->
getBlock());
554 size_t rhsBlockIndex = topoBlockIndices.at(rhs->
getBlock());
555 if (lhsBlockIndex == rhsBlockIndex)
557 return lhsBlockIndex < rhsBlockIndex;
// Removes every blocking use recorded by the analysis. Users are visited in
// reverse order of `usersToRemoveUses`, which is built from the keys of
// `userToBlockingUses` (presumably sorted with dominanceSort — confirm).
// Users reporting DeletionKind::Delete are collected in `toErase`.
561 void MemorySlotPromoter::removeBlockingUses() {
563 llvm::make_first_range(info.userToBlockingUses));
574 for (
Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
575 if (
auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
// A memory op receives the definition that reached it; the default value is
// used instead, presumably when none was recorded.
576 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
580 reachingDef = getOrCreateDefaultValue();
583 if (toPromoteMemOp.removeBlockingUses(
584 slot, info.userToBlockingUses[toPromote], builder, reachingDef,
585 dataLayout) == DeletionKind::Delete)
586 toErase.push_back(toPromote);
// Remember (store, stored value) pairs for ops that ask to visit replaced
// values.
// NOTE(review): `replacedValuesMap[...]` default-inserts an entry for stores
// with no recorded value; `lookup` would avoid that insertion if the
// container is a DenseMap — confirm its type.
587 if (toPromoteMemOp.storesTo(slot))
588 if (
Value replacedValue = replacedValuesMap[toPromoteMemOp])
589 replacedValuesList.push_back({toPromoteMemOp, replacedValue});
// Any other user must implement PromotableOpInterface (`cast` asserts this).
593 auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
595 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
596 builder) == DeletionKind::Delete)
597 toErase.push_back(toPromote);
598 if (toPromoteBasic.requiresReplacedValues())
599 toVisit.push_back(toPromoteBasic);
// Ops that requested it are given the collected replaced values.
601 for (PromotableOpInterface op : toVisit) {
603 op.visitReplacedValues(replacedValuesList, builder);
// Message of the final check that the slot pointer has no remaining uses.
610 "after promotion, the slot pointer should not be used anymore");
// Promotes the slot in four steps:
//   1. Propagate reaching definitions. The call ending below passes the
//      default value, presumably as the initial definition.
//   2. Remove the blocking uses.
//   3. Give branches into merge points the default value where they need it.
//   4. Let the allocator finish, and return its result.
613 std::optional<PromotableAllocationOpInterface>
614 MemorySlotPromoter::promoteSlot() {
616 getOrCreateDefaultValue());
619 removeBlockingUses();
// Patch predecessors of merge points that received no operand for the new
// block argument (the operand-count condition is presumably checked before
// appending — confirm).
623 for (
Block *mergePoint : info.mergePoints) {
625 auto user = cast<BranchOpInterface>(use.getOwner());
627 user.getSuccessorOperands(use.getOperandNumber());
631 succOperands.
append(getOrCreateDefaultValue());
635 LLVM_DEBUG(llvm::dbgs() <<
"[mem2reg] Promoted memory slot: " << slot.
ptr
// The allocator finalizes the promotion; a returned allocator is revisited by
// the caller.
641 return allocator.handlePromotionComplete(slot, defaultValue, builder);
// Allocators are processed in rounds. In each round:
//   - Every promotable slot is analyzed and, presumably when the analysis
//     succeeds, promoted.
//   - An allocator returned by the promoter is queued for the next round.
//   - Unchanged allocators are kept for the next round.
// Rounds presumably stop once one makes no change. The result is success only
// if at least one slot was promoted.
648 bool promotedAny =
false;
// Block indices are cached once for all promotions performed in this call.
653 BlockIndexCache blockIndexCache;
658 newWorkList.reserve(workList.size());
660 bool changesInThisRound =
false;
661 for (PromotableAllocationOpInterface allocator : workList) {
662 bool changedAllocator =
false;
663 for (
MemorySlot slot : allocator.getPromotableSlots()) {
667 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
668 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
670 std::optional<PromotableAllocationOpInterface> newAllocator =
671 MemorySlotPromoter(slot, allocator, builder, dominance,
672 dataLayout, std::move(*info), statistics,
675 changedAllocator =
true;
// A promotion may hand back an allocator to retry in the next round.
679 newWorkList.push_back(*newAllocator);
// Allocators that were not changed are retried in the next round.
686 if (!changedAllocator)
687 newWorkList.push_back(allocator);
688 changesInThisRound |= changedAllocator;
690 if (!changesInThisRound)
696 workList.swap(newWorkList);
700 return success(promotedAny);
// The mem2reg pass. It gathers the PromotableAllocationOpInterface ops nested
// in the regions it processes and passes them, with the data layout and
// dominance information, to the promotion driver (presumably
// tryToPromoteMemorySlots).
705 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
706 using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
708 void runOnOperation()
override {
// Data layout in effect at `scopeOp`, plus dominance info for the promotion.
715 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
716 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
717 auto &dominance = getAnalysis<DominanceInfo>();
// Collect every promotable allocator nested in `region`.
727 region.
walk([&](PromotableAllocationOpInterface allocator) {
728 allocators.emplace_back(allocator);
733 dominance, statistics)))
// Presumably reached when nothing was promoted, so all analyses stay valid —
// confirm against the condition above.
737 markAllAnalysesPreserved();
static void dominanceSort(SmallVector< Operation * > &ops, Region &region, BlockIndexCache &blockIndexCache)
Sorts ops according to dominance.
static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region)
Gets or creates a block index mapping for region.
llvm::IDFCalculatorBase< Block, false > IDFCalculator
This class represents an argument of a Block.
A block operand represents an operand that holds a reference to a Block, e.g. for terminator operations.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
This class models how operands are forwarded to block arguments in control flow.
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.