MLIR  22.0.0git
SROA.cpp
Go to the documentation of this file.
1 //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/SROA.h"
14 #include "mlir/Transforms/Passes.h"
15 #include "llvm/Support/DebugLog.h"
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_SROA
19 #include "mlir/Transforms/Passes.h.inc"
20 } // namespace mlir
21 
22 #define DEBUG_TYPE "sroa"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// Information computed by destructurable memory slot analysis used to perform
29 /// actual destructuring of the slot. This struct is only constructed if
30 /// destructuring is possible, and contains the necessary data to perform it.
31 struct MemorySlotDestructuringInfo {
// `usedIndices` is filled in through `canRewire` during the analysis and is
// later handed to `allocator.destructure`.
32  /// Set of the indices that are actually used when accessing the subelements.
33  SmallPtrSet<Attribute, 8> usedIndices;
34  /// Blocking uses of a given user of the memory slot that must be eliminated.
// NOTE(review): the member declaration documented by the comment above (source
// line 35) is missing from this listing. The rest of the file accesses it as
// `info.userToBlockingUses`, indexed by the `Operation *` owning a use, with a
// set of `OpOperand *` as the mapped value. Confirm the exact type upstream.
36  /// List of potentially indirect accessors of the memory slot that need
37  /// rewiring.
// NOTE(review): the member declaration documented by the comment above (source
// line 38) is missing from this listing. The rest of the file accesses it as
// `info.accessors` and appends `DestructurableAccessorOpInterface` values to
// it with `push_back`. Confirm the exact type upstream.
39 };
40 
41 } // namespace
42 
43 /// Computes information for slot destructuring. This will compute whether this
44 /// slot can be destructured and data to perform the destructuring. Returns
45 /// nothing if the slot cannot be destructured or if there is no useful work to
46 /// be done.
// NOTE(review): the line holding the function name and its first parameter
// (source line 48) is missing from this listing. The cross-reference index at
// the end of this page gives the signature as
// `computeDestructuringInfo(DestructurableMemorySlot &slot,
//                           const DataLayout &dataLayout)`.
47 static std::optional<MemorySlotDestructuringInfo>
49  const DataLayout &dataLayout) {
// The element type of the slot is required to implement
// DestructurableTypeInterface.
50  assert(isa<DestructurableTypeInterface>(slot.elemType));
51 
// A slot pointer without any use leaves nothing to rewire: report that there
// is no work to do.
52  if (slot.ptr.use_empty())
53  return {};
54 
55  MemorySlotDestructuringInfo info;
56 
// Slots whose uses still have to be checked. The list is handed to
// `canRewire` and `ensureOnlySafeAccesses` below (presumably so that they can
// append further slots to verify) and is drained by the worklist loop.
57  SmallVector<MemorySlot> usedSafelyWorklist;
58 
// Records `use` in the blocking-use set of the operation that owns it.
59  auto scheduleAsBlockingUse = [&](OpOperand &use) {
60  SmallPtrSetImpl<OpOperand *> &blockingUses =
61  info.userToBlockingUses[use.getOwner()];
62  blockingUses.insert(&use);
63  };
64 
65  // Initialize the analysis with the immediate users of the slot.
66  for (OpOperand &use : slot.ptr.getUses()) {
67  if (auto accessor =
68  dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
69  if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
70  dataLayout)) {
71  info.accessors.push_back(accessor);
72  continue;
73  }
74  }
75 
76  // If it cannot be shown that the operation uses the slot safely, maybe it
77  // can be promoted out of using the slot?
78  scheduleAsBlockingUse(use);
79  }
80 
// NOTE(review): source line 81, which declares `visited`, is missing from this
// listing. `visited` is used below so that each use is processed at most once.
// Every use of a slot taken from the worklist must either be proven safe by
// its owner or be recorded as a blocking use.
82  while (!usedSafelyWorklist.empty()) {
83  MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
84  for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
85  if (!visited.insert(&subslotUse).second)
86  continue;
87  Operation *subslotUser = subslotUse.getOwner();
88 
89  if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
90  if (succeeded(memOp.ensureOnlySafeAccesses(
91  mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
92  continue;
93 
94  // If it cannot be shown that the operation uses the slot safely, maybe it
95  // can be promoted out of using the slot?
96  scheduleAsBlockingUse(subslotUse);
97  }
98  }
99 
// Walk the operations in the forward slice of the slot pointer and check that
// every operation with recorded blocking uses is able to get rid of them.
100  SetVector<Operation *> forwardSlice;
101  mlir::getForwardSlice(slot.ptr, &forwardSlice);
102  for (Operation *user : forwardSlice) {
103  // If the next operation has no blocking uses, everything is fine.
104  auto it = info.userToBlockingUses.find(user);
105  if (it == info.userToBlockingUses.end())
106  continue;
107 
108  SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
109  auto promotable = dyn_cast<PromotableOpInterface>(user);
110 
111  // An operation that has blocking uses must be promoted. If it is not
112  // promotable, destructuring must fail.
113  if (!promotable)
114  return {};
115 
116  SmallVector<OpOperand *> newBlockingUses;
117  // If the operation decides it cannot deal with removing the blocking uses,
118  // destructuring must fail.
119  if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
120  return {};
121 
122  // Then, register any new blocking uses for coming operations.
123  for (OpOperand *blockingUse : newBlockingUses) {
// A new blocking use has to be a use of one of the results of `user`.
124  assert(llvm::is_contained(user->getResults(), blockingUse->get()));
125 
126  SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
127  info.userToBlockingUses[blockingUse->getOwner()];
128  newUserBlockingUseSet.insert(blockingUse);
129  }
130  }
131 
132  return info;
133 }
134 
135 /// Performs the destructuring of a destructible slot given associated
136 /// destructuring information. The provided slot will be destructured in
137 /// subslots as specified by its allocator.
// NOTE(review): source lines 139 and 142 of the parameter list are missing
// from this listing. The cross-reference index at the end of this page gives
// the signature as
// `destructureSlot(DestructurableMemorySlot &slot,
//                  DestructurableAllocationOpInterface allocator,
//                  OpBuilder &builder, const DataLayout &dataLayout,
//                  MemorySlotDestructuringInfo &info,
//                  SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
//                  const SROAStatistics &statistics)`.
138 static void destructureSlot(
140  DestructurableAllocationOpInterface allocator, OpBuilder &builder,
141  const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
143  const SROAStatistics &statistics) {
// Restores the insertion point of `builder` when this function returns.
144  OpBuilder::InsertionGuard guard(builder);
145 
// NOTE(review): source lines 146-147 are missing from this listing, so the
// beginning of the statement below is not visible. `subslots`, used further
// down, is presumably declared there from the result of this call - confirm
// against upstream.
148  allocator.destructure(slot, info.usedIndices, builder, newAllocators);
149 
// Count the slot as a memory benefit when the number of used indices differs
// from the number of subelements of the slot.
150  if (statistics.slotsWithMemoryBenefit &&
151  slot.subelementTypes.size() != info.usedIndices.size())
152  (*statistics.slotsWithMemoryBenefit)++;
153 
// Keep track of the largest subelement count seen on a destructured slot.
154  if (statistics.maxSubelementAmount)
155  statistics.maxSubelementAmount->updateMax(slot.subelementTypes.size());
156 
// Gather every operation that has to be updated (owners of blocking uses and
// accessors) and sort them topologically.
157  SetVector<Operation *> usersToRewire;
158  usersToRewire.insert_range(llvm::make_first_range(info.userToBlockingUses));
159  usersToRewire.insert_range(info.accessors);
160  usersToRewire = mlir::topologicalSort(usersToRewire);
161 
// NOTE(review): source line 162, which declares `toErase`, and source line
// 167, which completes the comparison of the `rewire` result, are missing from
// this listing. By analogy with the `removeBlockingUses` check below, line 167
// is presumably `DeletionKind::Delete)` - confirm against upstream.
// The users are visited in reverse topological order. Operations to delete are
// only collected in the loop and erased once every user has been handled.
163  for (Operation *toRewire : llvm::reverse(usersToRewire)) {
164  builder.setInsertionPointAfter(toRewire);
165  if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
166  if (accessor.rewire(slot, subslots, builder, dataLayout) ==
168  toErase.push_back(accessor);
169  continue;
170  }
171 
// Any user that is not an accessor was recorded because of blocking uses, and
// the analysis only succeeds when such users are promotable.
172  auto promotable = cast<PromotableOpInterface>(toRewire);
173  if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
174  builder) == DeletionKind::Delete)
175  toErase.push_back(promotable);
176  }
177 
178  for (Operation *toEraseOp : toErase)
179  toEraseOp->erase();
180 
181  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
182  "pointer should no longer be used");
183 
184  LDBG() << "Destructured memory slot: " << slot.ptr;
185 
186  if (statistics.destructuredAmount)
187  (*statistics.destructuredAmount)++;
188 
// Let the allocator finish the destructuring; it may hand back a new allocator.
189  std::optional<DestructurableAllocationOpInterface> newAllocator =
190  allocator.handleDestructuringComplete(slot, builder);
191  // Add newly created allocators to the worklist for further processing.
192  if (newAllocator)
193  newAllocators.push_back(*newAllocator);
194 }
195 
198  OpBuilder &builder, const DataLayout &dataLayout,
199  SROAStatistics statistics) {
// NOTE(review): the beginning of this definition (source lines 196-197) is
// missing from this listing. The cross-reference index at the end of this page
// gives the declaration as
// `LogicalResult tryToDestructureMemorySlots(
//      ArrayRef<DestructurableAllocationOpInterface> allocators,
//      OpBuilder &builder, const DataLayout &dataLayout,
//      SROAStatistics statistics={})`.
// The function returns success only if at least one slot was destructured.
200  bool destructuredAny = false;
201 
// NOTE(review): source lines 202-203, which declare `workList` and
// `newWorkList`, are missing from this listing. `workList` is the set of
// allocators processed in the current round and `newWorkList` collects the
// allocators for the next round.
204  newWorkList.reserve(allocators.size());
205  // Destructuring a slot can allow for further destructuring of other
206  // slots, destructuring is tried until no destructuring succeeds.
207  while (true) {
208  bool changesInThisRound = false;
209 
210  for (DestructurableAllocationOpInterface allocator : workList) {
211  bool destructuredAnySlot = false;
212  for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
213  std::optional<MemorySlotDestructuringInfo> info =
214  computeDestructuringInfo(slot, dataLayout);
215  if (!info)
216  continue;
217 
// `newWorkList` is passed as the list that receives newly created allocators,
// so those are examined in the next round.
218  destructureSlot(slot, allocator, builder, dataLayout, *info,
219  newWorkList, statistics);
220  destructuredAnySlot = true;
221 
222  // A break is required, since destructuring a slot may invalidate the
223  // remaining slots of an allocator.
224  break;
225  }
// An allocator for which nothing could be destructured is retried in the next
// round.
226  if (!destructuredAnySlot)
227  newWorkList.push_back(allocator);
228  changesInThisRound |= destructuredAnySlot;
229  }
230 
// Stop once a full round over the worklist did not destructure anything.
231  if (!changesInThisRound)
232  break;
233  destructuredAny |= changesInThisRound;
234 
235  // Swap the vector's backing memory and clear the entries in newWorkList
236  // afterwards. This ensures that additional heap allocations can be avoided.
237  workList.swap(newWorkList);
238  newWorkList.clear();
239  }
240 
241  return success(destructuredAny);
242 }
243 
244 namespace {
245 
// Pass that tries to destructure the memory slots of the destructurable
// allocators found in each region of the operation it runs on.
246 struct SROA : public impl::SROABase<SROA> {
247  using impl::SROABase<SROA>::SROABase;
248 
249  void runOnOperation() override {
250  Operation *scopeOp = getOperation();
251 
// Expose the statistic counters of the pass to the destructuring routines.
252  SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
253  &maxSubelementAmount};
254 
255  auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
256  const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
257  bool changed = false;
258 
259  for (Region &region : scopeOp->getRegions()) {
// A region without blocks has nothing to process.
260  if (region.getBlocks().empty())
261  continue;
262 
// The builder starts at the beginning of the first block of the region.
263  OpBuilder builder(&region.front(), region.front().begin());
264 
// NOTE(review): source line 265, which declares `allocators`, is missing from
// this listing. It is filled by the walk below and then passed to
// `tryToDestructureMemorySlots`.
266  // Build a list of allocators to attempt to destructure the slots of.
267  region.walk([&](DestructurableAllocationOpInterface allocator) {
268  allocators.emplace_back(allocator);
269  });
270 
271  // Attempt to destructure as many slots as possible.
272  if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
273  statistics)))
274  changed = true;
275  }
// Nothing was destructured in any region, so all analyses are marked as
// preserved.
276  if (!changed)
277  markAllAnalysesPreserved();
278  }
279 };
280 
281 } // namespace
static void destructureSlot(DestructurableMemorySlot &slot, DestructurableAllocationOpInterface allocator, OpBuilder &builder, const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, SmallVectorImpl< DestructurableAllocationOpInterface > &newAllocators, const SROAStatistics &statistics)
Performs the destructuring of a destructible slot given associated destructuring information.
Definition: SROA.cpp:138
static std::optional< MemorySlotDestructuringInfo > computeDestructuringInfo(DestructurableMemorySlot &slot, const DataLayout &dataLayout)
Computes information for slot destructuring.
Definition: SROA.cpp:48
The main mechanism for performing data layout queries.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
@ Delete
Delete the operation after promotion.
LogicalResult tryToDestructureMemorySlots(ArrayRef< DestructurableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, SROAStatistics statistics={})
Attempts to destructure the slots of destructurable allocators.
Definition: SROA.cpp:196
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.
Statistics collected while applying SROA.
Definition: SROA.h:19
llvm::Statistic * maxSubelementAmount
Maximal number of sub-elements a successfully destructured slot initially had.
Definition: SROA.h:27
llvm::Statistic * slotsWithMemoryBenefit
Total amount of memory slots in which the destructured size was smaller than the total size after eliminating unused fields.
Definition: SROA.h:24
llvm::Statistic * destructuredAmount
Total amount of memory slots destructured.
Definition: SROA.h:21