MLIR  19.0.0git
SROA.cpp
Go to the documentation of this file.
1 //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/SROA.h"
13 #include "mlir/Transforms/Passes.h"
14 
15 namespace mlir {
16 #define GEN_PASS_DEF_SROA
17 #include "mlir/Transforms/Passes.h.inc"
18 } // namespace mlir
19 
20 #define DEBUG_TYPE "sroa"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 /// Information computed by destructurable memory slot analysis used to perform
27 /// actual destructuring of the slot. This struct is only constructed if
28 /// destructuring is possible, and contains the necessary data to perform it.
29 struct MemorySlotDestructuringInfo {
30  /// Set of the indices that are actually used when accessing the subelements.
31  SmallPtrSet<Attribute, 8> usedIndices;
32  /// Blocking uses of a given user of the memory slot that must be eliminated.
34  /// List of potentially indirect accessors of the memory slot that need
35  /// rewiring.
37 };
38 
39 } // namespace
40 
41 /// Computes information for slot destructuring. This will compute whether this
42 /// slot can be destructured and data to perform the destructuring. Returns
43 /// nothing if the slot cannot be destructured or if there is no useful work to
44 /// be done.
45 static std::optional<MemorySlotDestructuringInfo>
47  const DataLayout &dataLayout) {
48  assert(isa<DestructurableTypeInterface>(slot.elemType));
49 
50  if (slot.ptr.use_empty())
51  return {};
52 
53  MemorySlotDestructuringInfo info;
54 
55  SmallVector<MemorySlot> usedSafelyWorklist;
56 
57  auto scheduleAsBlockingUse = [&](OpOperand &use) {
58  SmallPtrSetImpl<OpOperand *> &blockingUses =
59  info.userToBlockingUses.getOrInsertDefault(use.getOwner());
60  blockingUses.insert(&use);
61  };
62 
63  // Initialize the analysis with the immediate users of the slot.
64  for (OpOperand &use : slot.ptr.getUses()) {
65  if (auto accessor =
66  dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
67  if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
68  dataLayout)) {
69  info.accessors.push_back(accessor);
70  continue;
71  }
72  }
73 
74  // If it cannot be shown that the operation uses the slot safely, maybe it
75  // can be promoted out of using the slot?
76  scheduleAsBlockingUse(use);
77  }
78 
80  while (!usedSafelyWorklist.empty()) {
81  MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
82  for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
83  if (!visited.insert(&subslotUse).second)
84  continue;
85  Operation *subslotUser = subslotUse.getOwner();
86 
87  if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
88  if (succeeded(memOp.ensureOnlySafeAccesses(
89  mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
90  continue;
91 
92  // If it cannot be shown that the operation uses the slot safely, maybe it
93  // can be promoted out of using the slot?
94  scheduleAsBlockingUse(subslotUse);
95  }
96  }
97 
98  SetVector<Operation *> forwardSlice;
99  mlir::getForwardSlice(slot.ptr, &forwardSlice);
100  for (Operation *user : forwardSlice) {
101  // If the next operation has no blocking uses, everything is fine.
102  if (!info.userToBlockingUses.contains(user))
103  continue;
104 
105  SmallPtrSet<OpOperand *, 4> &blockingUses = info.userToBlockingUses[user];
106  auto promotable = dyn_cast<PromotableOpInterface>(user);
107 
108  // An operation that has blocking uses must be promoted. If it is not
109  // promotable, destructuring must fail.
110  if (!promotable)
111  return {};
112 
113  SmallVector<OpOperand *> newBlockingUses;
114  // If the operation decides it cannot deal with removing the blocking uses,
115  // destructuring must fail.
116  if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
117  return {};
118 
119  // Then, register any new blocking uses for coming operations.
120  for (OpOperand *blockingUse : newBlockingUses) {
121  assert(llvm::is_contained(user->getResults(), blockingUse->get()));
122 
123  SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
124  info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
125  newUserBlockingUseSet.insert(blockingUse);
126  }
127  }
128 
129  return info;
130 }
131 
132 /// Performs the destructuring of a destructible slot given associated
133 /// destructuring information. The provided slot will be destructured in
134 /// subslots as specified by its allocator.
135 static void destructureSlot(
137  DestructurableAllocationOpInterface allocator, OpBuilder &builder,
138  const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
140  const SROAStatistics &statistics) {
141  OpBuilder::InsertionGuard guard(builder);
142 
145  allocator.destructure(slot, info.usedIndices, builder, newAllocators);
146 
147  if (statistics.slotsWithMemoryBenefit &&
148  slot.elementPtrs.size() != info.usedIndices.size())
149  (*statistics.slotsWithMemoryBenefit)++;
150 
151  if (statistics.maxSubelementAmount)
152  statistics.maxSubelementAmount->updateMax(slot.elementPtrs.size());
153 
154  SetVector<Operation *> usersToRewire;
155  for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
156  usersToRewire.insert(user);
157  for (DestructurableAccessorOpInterface accessor : info.accessors)
158  usersToRewire.insert(accessor);
159  usersToRewire = mlir::topologicalSort(usersToRewire);
160 
162  for (Operation *toRewire : llvm::reverse(usersToRewire)) {
163  builder.setInsertionPointAfter(toRewire);
164  if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
165  if (accessor.rewire(slot, subslots, builder, dataLayout) ==
167  toErase.push_back(accessor);
168  continue;
169  }
170 
171  auto promotable = cast<PromotableOpInterface>(toRewire);
172  if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
173  builder) == DeletionKind::Delete)
174  toErase.push_back(promotable);
175  }
176 
177  for (Operation *toEraseOp : toErase)
178  toEraseOp->erase();
179 
180  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
181  "pointer should no longer be used");
182 
183  LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
184  << "\n");
185 
186  if (statistics.destructuredAmount)
187  (*statistics.destructuredAmount)++;
188 
189  std::optional<DestructurableAllocationOpInterface> newAllocator =
190  allocator.handleDestructuringComplete(slot, builder);
191  // Add newly created allocators to the worklist for further processing.
192  if (newAllocator)
193  newAllocators.push_back(*newAllocator);
194 }
195 
198  OpBuilder &builder, const DataLayout &dataLayout,
199  SROAStatistics statistics) {
200  bool destructuredAny = false;
201 
202  SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
203  allocators.end());
205  newWorkList.reserve(allocators.size());
206  // Destructuring a slot can allow for further destructuring of other
207  // slots, destructuring is tried until no destructuring succeeds.
208  while (true) {
209  bool changesInThisRound = false;
210 
211  for (DestructurableAllocationOpInterface allocator : workList) {
212  bool destructuredAnySlot = false;
213  for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
214  std::optional<MemorySlotDestructuringInfo> info =
215  computeDestructuringInfo(slot, dataLayout);
216  if (!info)
217  continue;
218 
219  destructureSlot(slot, allocator, builder, dataLayout, *info,
220  newWorkList, statistics);
221  destructuredAnySlot = true;
222 
223  // A break is required, since destructuring a slot may invalidate the
224  // remaning slots of an allocator.
225  break;
226  }
227  if (!destructuredAnySlot)
228  newWorkList.push_back(allocator);
229  changesInThisRound |= destructuredAnySlot;
230  }
231 
232  if (!changesInThisRound)
233  break;
234  destructuredAny |= changesInThisRound;
235 
236  // Swap the vector's backing memory and clear the entries in newWorkList
237  // afterwards. This ensures that additional heap allocations can be avoided.
238  workList.swap(newWorkList);
239  newWorkList.clear();
240  }
241 
242  return success(destructuredAny);
243 }
244 
245 namespace {
246 
247 struct SROA : public impl::SROABase<SROA> {
248  using impl::SROABase<SROA>::SROABase;
249 
250  void runOnOperation() override {
251  Operation *scopeOp = getOperation();
252 
253  SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
254  &maxSubelementAmount};
255 
256  auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
257  const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
258  bool changed = false;
259 
260  for (Region &region : scopeOp->getRegions()) {
261  if (region.getBlocks().empty())
262  continue;
263 
264  OpBuilder builder(&region.front(), region.front().begin());
265 
267  // Build a list of allocators to attempt to destructure the slots of.
268  region.walk([&](DestructurableAllocationOpInterface allocator) {
269  allocators.emplace_back(allocator);
270  });
271 
272  // Attempt to destructure as many slots as possible.
273  if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
274  statistics)))
275  changed = true;
276  }
277  if (!changed)
278  markAllAnalysesPreserved();
279  }
280 };
281 
282 } // namespace
static void destructureSlot(DestructurableMemorySlot &slot, DestructurableAllocationOpInterface allocator, OpBuilder &builder, const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, SmallVectorImpl< DestructurableAllocationOpInterface > &newAllocators, const SROAStatistics &statistics)
Performs the destructuring of a destructible slot given associated destructuring information.
Definition: SROA.cpp:135
static std::optional< MemorySlotDestructuringInfo > computeDestructuringInfo(DestructurableMemorySlot &slot, const DataLayout &dataLayout)
Computes information for slot destructuring.
Definition: SROA.cpp:46
The main mechanism for performing data layout queries.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Include the generated interface declarations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ Delete
Delete the operation after promotion.
LogicalResult tryToDestructureMemorySlots(ArrayRef< DestructurableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, SROAStatistics statistics={})
Attempts to destructure the slots of destructurable allocators.
Definition: SROA.cpp:196
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.
Statistics collected while applying SROA.
Definition: SROA.h:19
llvm::Statistic * maxSubelementAmount
Maximal number of sub-elements a successfully destructured slot initially had.
Definition: SROA.h:27
llvm::Statistic * slotsWithMemoryBenefit
Total amount of memory slots in which the destructured size was smaller than the total size after eliminating unused fields.
Definition: SROA.h:24
llvm::Statistic * destructuredAmount
Total amount of memory slots destructured.
Definition: SROA.h:21