MLIR  18.0.0git
SROA.cpp
Go to the documentation of this file.
1 //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/SROA.h"
13 #include "mlir/Transforms/Passes.h"
14 
15 namespace mlir {
16 #define GEN_PASS_DEF_SROA
17 #include "mlir/Transforms/Passes.h.inc"
18 } // namespace mlir
19 
20 #define DEBUG_TYPE "sroa"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 /// Information computed by destructurable memory slot analysis used to perform
27 /// actual destructuring of the slot. This struct is only constructed if
28 /// destructuring is possible, and contains the necessary data to perform it.
29 struct MemorySlotDestructuringInfo {
30  /// Set of the indices that are actually used when accessing the subelements.
31  SmallPtrSet<Attribute, 8> usedIndices;
32  /// Blocking uses of a given user of the memory slot that must be eliminated.
34  /// List of potentially indirect accessors of the memory slot that need
35  /// rewiring.
37 };
38 
39 } // namespace
40 
41 /// Computes information for slot destructuring. This will compute whether this
42 /// slot can be destructured and data to perform the destructuring. Returns
43 /// nothing if the slot cannot be destructured or if there is no useful work to
44 /// be done.
45 static std::optional<MemorySlotDestructuringInfo>
47  assert(isa<DestructurableTypeInterface>(slot.elemType));
48 
49  if (slot.ptr.use_empty())
50  return {};
51 
52  MemorySlotDestructuringInfo info;
53 
54  SmallVector<MemorySlot> usedSafelyWorklist;
55 
56  auto scheduleAsBlockingUse = [&](OpOperand &use) {
57  SmallPtrSetImpl<OpOperand *> &blockingUses =
58  info.userToBlockingUses.getOrInsertDefault(use.getOwner());
59  blockingUses.insert(&use);
60  };
61 
62  // Initialize the analysis with the immediate users of the slot.
63  for (OpOperand &use : slot.ptr.getUses()) {
64  if (auto accessor =
65  dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
66  if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) {
67  info.accessors.push_back(accessor);
68  continue;
69  }
70  }
71 
72  // If it cannot be shown that the operation uses the slot safely, maybe it
73  // can be promoted out of using the slot?
74  scheduleAsBlockingUse(use);
75  }
76 
78  while (!usedSafelyWorklist.empty()) {
79  MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
80  for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
81  if (!visited.insert(&subslotUse).second)
82  continue;
83  Operation *subslotUser = subslotUse.getOwner();
84 
85  if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
86  if (succeeded(memOp.ensureOnlySafeAccesses(mustBeUsedSafely,
87  usedSafelyWorklist)))
88  continue;
89 
90  // If it cannot be shown that the operation uses the slot safely, maybe it
91  // can be promoted out of using the slot?
92  scheduleAsBlockingUse(subslotUse);
93  }
94  }
95 
96  SetVector<Operation *> forwardSlice;
97  mlir::getForwardSlice(slot.ptr, &forwardSlice);
98  for (Operation *user : forwardSlice) {
99  // If the next operation has no blocking uses, everything is fine.
100  if (!info.userToBlockingUses.contains(user))
101  continue;
102 
103  SmallPtrSet<OpOperand *, 4> &blockingUses = info.userToBlockingUses[user];
104  auto promotable = dyn_cast<PromotableOpInterface>(user);
105 
106  // An operation that has blocking uses must be promoted. If it is not
107  // promotable, destructuring must fail.
108  if (!promotable)
109  return {};
110 
111  SmallVector<OpOperand *> newBlockingUses;
112  // If the operation decides it cannot deal with removing the blocking uses,
113  // destructuring must fail.
114  if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
115  return {};
116 
117  // Then, register any new blocking uses for coming operations.
118  for (OpOperand *blockingUse : newBlockingUses) {
119  assert(llvm::is_contained(user->getResults(), blockingUse->get()));
120 
121  SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
122  info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
123  newUserBlockingUseSet.insert(blockingUse);
124  }
125  }
126 
127  return info;
128 }
129 
130 /// Performs the destructuring of a destructible slot given associated
131 /// destructuring information. The provided slot will be destructured in
132 /// subslots as specified by its allocator.
134  DestructurableAllocationOpInterface allocator,
135  RewriterBase &rewriter,
136  MemorySlotDestructuringInfo &info,
137  const SROAStatistics &statistics) {
138  RewriterBase::InsertionGuard guard(rewriter);
139 
142  allocator.destructure(slot, info.usedIndices, rewriter);
143 
144  if (statistics.slotsWithMemoryBenefit &&
145  slot.elementPtrs.size() != info.usedIndices.size())
146  (*statistics.slotsWithMemoryBenefit)++;
147 
148  if (statistics.maxSubelementAmount)
149  statistics.maxSubelementAmount->updateMax(slot.elementPtrs.size());
150 
151  SetVector<Operation *> usersToRewire;
152  for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
153  usersToRewire.insert(user);
154  for (DestructurableAccessorOpInterface accessor : info.accessors)
155  usersToRewire.insert(accessor);
156  usersToRewire = mlir::topologicalSort(usersToRewire);
157 
159  for (Operation *toRewire : llvm::reverse(usersToRewire)) {
160  rewriter.setInsertionPointAfter(toRewire);
161  if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
162  if (accessor.rewire(slot, subslots, rewriter) == DeletionKind::Delete)
163  toErase.push_back(accessor);
164  continue;
165  }
166 
167  auto promotable = cast<PromotableOpInterface>(toRewire);
168  if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
169  rewriter) == DeletionKind::Delete)
170  toErase.push_back(promotable);
171  }
172 
173  for (Operation *toEraseOp : toErase)
174  rewriter.eraseOp(toEraseOp);
175 
176  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
177  "pointer should no longer be used");
178 
179  LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
180  << "\n");
181 
182  if (statistics.destructuredAmount)
183  (*statistics.destructuredAmount)++;
184 
185  allocator.handleDestructuringComplete(slot, rewriter);
186 }
187 
190  RewriterBase &rewriter, SROAStatistics statistics) {
191  bool destructuredAny = false;
192 
193  for (DestructurableAllocationOpInterface allocator : allocators) {
194  for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
195  std::optional<MemorySlotDestructuringInfo> info =
197  if (!info)
198  continue;
199 
200  destructureSlot(slot, allocator, rewriter, *info, statistics);
201  destructuredAny = true;
202  }
203  }
204 
205  return success(destructuredAny);
206 }
207 
209 SROAPattern::matchAndRewrite(DestructurableAllocationOpInterface allocator,
210  PatternRewriter &rewriter) const {
212  return tryToDestructureMemorySlots({allocator}, rewriter, statistics);
213 }
214 
215 namespace {
216 
217 struct SROA : public impl::SROABase<SROA> {
218  using impl::SROABase<SROA>::SROABase;
219 
220  void runOnOperation() override {
221  Operation *scopeOp = getOperation();
222 
223  SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
224  &maxSubelementAmount};
225 
226  RewritePatternSet rewritePatterns(&getContext());
227  rewritePatterns.add<SROAPattern>(&getContext(), statistics);
228  FrozenRewritePatternSet frozen(std::move(rewritePatterns));
229 
230  if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen)))
231  signalPassFailure();
232  }
233 };
234 
235 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static std::optional< MemorySlotDestructuringInfo > computeDestructuringInfo(DestructurableMemorySlot &slot)
Computes information for slot destructuring.
Definition: SROA.cpp:46
static void destructureSlot(DestructurableMemorySlot &slot, DestructurableAllocationOpInterface allocator, RewriterBase &rewriter, MemorySlotDestructuringInfo &info, const SROAStatistics &statistics)
Performs the destructuring of a destructible slot given associated destructuring information.
Definition: SROA.cpp:133
This class represents a frozen set of patterns that can be processed by a pattern applicator.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:397
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:727
bool hasBoundedRewriteRecursion() const
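Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but the recursion is known to be bounded.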
Definition: PatternMatch.h:128
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Pattern applying SROA to the regions of the operations on which it matches.
Definition: SROA.h:35
LogicalResult matchAndRewrite(DestructurableAllocationOpInterface allocator, PatternRewriter &rewriter) const override
Definition: SROA.cpp:209
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Include the generated interface declarations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult tryToDestructureMemorySlots(ArrayRef< DestructurableAllocationOpInterface > allocators, RewriterBase &rewriter, SROAStatistics statistics={})
Attempts to destructure the slots of destructurable allocators.
Definition: SROA.cpp:188
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ Delete
Delete the operation after promotion.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including op itself.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the type of the pointer that will be generated to access the element directly.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.
Statistics collected while applying SROA.
Definition: SROA.h:21
llvm::Statistic * maxSubelementAmount
Maximal number of sub-elements a successfully destructured slot initially had.
Definition: SROA.h:29
llvm::Statistic * slotsWithMemoryBenefit
Total amount of memory slots in which the destructured size was smaller than the total size after eliminating unused fields.
Definition: SROA.h:26
llvm::Statistic * destructuredAmount
Total amount of memory slots destructured.
Definition: SROA.h:23