MLIR  19.0.0git
SROA.cpp
Go to the documentation of this file.
1 //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/SROA.h"
13 #include "mlir/Transforms/Passes.h"
14 
15 namespace mlir {
16 #define GEN_PASS_DEF_SROA
17 #include "mlir/Transforms/Passes.h.inc"
18 } // namespace mlir
19 
20 #define DEBUG_TYPE "sroa"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 /// Information computed by destructurable memory slot analysis used to perform
27 /// actual destructuring of the slot. This struct is only constructed if
28 /// destructuring is possible, and contains the necessary data to perform it.
29 struct MemorySlotDestructuringInfo {
30  /// Set of the indices that are actually used when accessing the subelements.
31  SmallPtrSet<Attribute, 8> usedIndices;
32  /// Blocking uses of a given user of the memory slot that must be eliminated.
    // NOTE(review): the field declaration documented by the comment above
    // (listing line 33) is missing from this extract. Later code accesses it as
    // `info.userToBlockingUses`, looked up by `Operation *` and yielding a set
    // of `OpOperand *` -- confirm the exact type against the original file.
34  /// List of potentially indirect accessors of the memory slot that need
35  /// rewiring.
    // NOTE(review): the field declaration documented by the comment above
    // (listing line 36) is missing from this extract. Later code accesses it as
    // `info.accessors` and iterates it as `DestructurableAccessorOpInterface`
    // values -- confirm the exact container type against the original file.
37 };
38 
39 } // namespace
40 
41 /// Computes information for slot destructuring. This will compute whether this
42 /// slot can be destructured and data to perform the destructuring. Returns
43 /// nothing if the slot cannot be destructured or if there is no useful work to
44 /// be done.
45 static std::optional<MemorySlotDestructuringInfo>
// NOTE(review): listing line 46 (function name and first parameter) is missing
// from this extract; the body below refers to the parameters as `slot` and
// `dataLayout`.
47  const DataLayout &dataLayout) {
48  assert(isa<DestructurableTypeInterface>(slot.elemType));
49 
    // A slot pointer without any use leaves no work to do.
50  if (slot.ptr.use_empty())
51  return {};
52 
53  MemorySlotDestructuringInfo info;
54 
    // Slots whose uses still have to be checked for safety. It is handed to
    // `canRewire` and `ensureOnlySafeAccesses` below and drained by the
    // worklist loop.
55  SmallVector<MemorySlot> usedSafelyWorklist;
56 
    // Records `use` in the blocking-use set of the operation that owns it.
57  auto scheduleAsBlockingUse = [&](OpOperand &use) {
58  SmallPtrSetImpl<OpOperand *> &blockingUses =
59  info.userToBlockingUses.getOrInsertDefault(use.getOwner());
60  blockingUses.insert(&use);
61  };
62 
63  // Initialize the analysis with the immediate users of the slot.
64  for (OpOperand &use : slot.ptr.getUses()) {
65  if (auto accessor =
66  dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
67  if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
68  dataLayout)) {
69  info.accessors.push_back(accessor);
70  continue;
71  }
72  }
73 
74  // If it cannot be shown that the operation uses the slot safely, maybe it
75  // can be promoted out of using the slot?
76  scheduleAsBlockingUse(use);
77  }
78 
    // NOTE(review): listing line 79 is missing from this extract; the loop
    // below uses a set named `visited` (of `OpOperand *`) that is presumably
    // declared there -- confirm against the original file.
80  while (!usedSafelyWorklist.empty()) {
81  MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
82  for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
    // Each use is processed at most once.
83  if (!visited.insert(&subslotUse).second)
84  continue;
85  Operation *subslotUser = subslotUse.getOwner();
86 
87  if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
88  if (succeeded(memOp.ensureOnlySafeAccesses(
89  mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
90  continue;
91 
92  // If it cannot be shown that the operation uses the slot safely, maybe it
93  // can be promoted out of using the slot?
94  scheduleAsBlockingUse(subslotUse);
95  }
96  }
97 
    // Check every operation in the forward slice of the slot pointer that has
    // blocking uses registered.
98  SetVector<Operation *> forwardSlice;
99  mlir::getForwardSlice(slot.ptr, &forwardSlice);
100  for (Operation *user : forwardSlice) {
101  // If the next operation has no blocking uses, everything is fine.
102  if (!info.userToBlockingUses.contains(user))
103  continue;
104 
105  SmallPtrSet<OpOperand *, 4> &blockingUses = info.userToBlockingUses[user];
106  auto promotable = dyn_cast<PromotableOpInterface>(user);
107 
108  // An operation that has blocking uses must be promoted. If it is not
109  // promotable, destructuring must fail.
110  if (!promotable)
111  return {};
112 
113  SmallVector<OpOperand *> newBlockingUses;
114  // If the operation decides it cannot deal with removing the blocking uses,
115  // destructuring must fail.
116  if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
117  return {};
118 
119  // Then, register any new blocking uses for coming operations.
120  for (OpOperand *blockingUse : newBlockingUses) {
    // A new blocking use must be a use of one of `user`'s own results.
121  assert(llvm::is_contained(user->getResults(), blockingUse->get()));
122 
123  SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
124  info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
125  newUserBlockingUseSet.insert(blockingUse);
126  }
127  }
128 
129  return info;
130 }
131 
132 /// Performs the destructuring of a destructurable slot given associated
133 /// destructuring information. The provided slot will be destructured in
134 /// subslots as specified by its allocator.
// NOTE(review): listing line 135 (return type, function name and first
// parameter) is missing from this extract; the body below refers to the first
// parameter as `slot`.
136  DestructurableAllocationOpInterface allocator,
137  RewriterBase &rewriter,
138  const DataLayout &dataLayout,
139  MemorySlotDestructuringInfo &info,
140  const SROAStatistics &statistics) {
    // Restores the rewriter's insertion point when this function returns.
141  RewriterBase::InsertionGuard guard(rewriter);
142 
    // NOTE(review): listing lines 143-144 are missing from this extract. The
    // next line looks like the tail of a statement whose result is used below
    // as `subslots` -- confirm against the original file.
145  allocator.destructure(slot, info.usedIndices, rewriter);
146 
    // Statistics pointers may be null, so each one is checked before use. The
    // memory-benefit counter is bumped when the number of used indices differs
    // from the number of elements in the slot.
147  if (statistics.slotsWithMemoryBenefit &&
148  slot.elementPtrs.size() != info.usedIndices.size())
149  (*statistics.slotsWithMemoryBenefit)++;
150 
151  if (statistics.maxSubelementAmount)
152  statistics.maxSubelementAmount->updateMax(slot.elementPtrs.size());
153 
    // Gather every operation that needs rewriting (users with blocking uses
    // and accessors) and sort them topologically.
154  SetVector<Operation *> usersToRewire;
155  for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
156  usersToRewire.insert(user);
157  for (DestructurableAccessorOpInterface accessor : info.accessors)
158  usersToRewire.insert(accessor);
159  usersToRewire = mlir::topologicalSort(usersToRewire);
160 
    // NOTE(review): listing line 161 is missing from this extract; the list
    // `toErase` used below is presumably declared there -- confirm.
    // Operations are processed in reverse topological order.
162  for (Operation *toRewire : llvm::reverse(usersToRewire)) {
163  rewriter.setInsertionPointAfter(toRewire);
164  if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
165  if (accessor.rewire(slot, subslots, rewriter, dataLayout) ==
    // NOTE(review): listing line 166 (right-hand side of the comparison above)
    // is missing from this extract; the parallel check further down compares
    // against `DeletionKind::Delete` -- confirm.
167  toErase.push_back(accessor);
168  continue;
169  }
170 
    // Anything that is not an accessor must be promotable; `cast` asserts it.
171  auto promotable = cast<PromotableOpInterface>(toRewire);
172  if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
173  rewriter) == DeletionKind::Delete)
174  toErase.push_back(promotable);
175  }
176 
    // Erasure is deferred until every operation has been rewired.
177  for (Operation *toEraseOp : toErase)
178  rewriter.eraseOp(toEraseOp);
179 
180  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
181  "pointer should no longer be used");
182 
183  LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
184  << "\n");
185 
186  if (statistics.destructuredAmount)
187  (*statistics.destructuredAmount)++;
188 
    // Notify the allocator once the slot has been fully destructured.
189  allocator.handleDestructuringComplete(slot, rewriter);
190 }
191 
// Attempts to destructure the slots exposed by each allocator in `allocators`.
// Returns success if at least one slot was destructured, failure otherwise.
// NOTE(review): listing lines 192-193 (return type, function name and the
// `allocators` parameter) are missing from this extract.
194  RewriterBase &rewriter, const DataLayout &dataLayout,
195  SROAStatistics statistics) {
196  bool destructuredAny = false;
197 
198  for (DestructurableAllocationOpInterface allocator : allocators) {
199  for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
    // No info means the slot cannot be destructured (or there is nothing
    // useful to do); move on to the next slot.
200  std::optional<MemorySlotDestructuringInfo> info =
201  computeDestructuringInfo(slot, dataLayout);
202  if (!info)
203  continue;
204 
205  destructureSlot(slot, allocator, rewriter, dataLayout, *info, statistics);
206  destructuredAny = true;
207  }
208  }
209 
    // Success only if some slot was actually destructured.
210  return success(destructuredAny);
211 }
212 
213 namespace {
214 
// Pass that repeatedly attempts to destructure the memory slots found in each
// region of the operation it runs on.
215 struct SROA : public impl::SROABase<SROA> {
216  using impl::SROABase<SROA>::SROABase;
217 
    // Runs destructuring on every non-empty region of the pass's operation and
    // marks all analyses preserved when nothing changed.
218  void runOnOperation() override {
219  Operation *scopeOp = getOperation();
220 
    // Bundle pointers to the pass statistics so the helpers can update them.
221  SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
222  &maxSubelementAmount};
223 
224  auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
225  const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
226  bool changed = false;
227 
228  for (Region &region : scopeOp->getRegions()) {
229  if (region.getBlocks().empty())
230  continue;
231 
232  OpBuilder builder(&region.front(), region.front().begin());
233  IRRewriter rewriter(builder);
234 
235  // Destructuring a slot can allow for further destructuring of other
236  // slots, destructuring is tried until no destructuring succeeds.
237  while (true) {
    // NOTE(review): listing line 238 is missing from this extract; the
    // `allocators` list filled below is presumably declared there -- confirm.
239  // Build a list of allocators to attempt to destructure the slots of.
240  // TODO: Update list on the fly to avoid repeated visiting of the same
241  // allocators.
242  region.walk([&](DestructurableAllocationOpInterface allocator) {
243  allocators.emplace_back(allocator);
244  });
245 
    // A failed attempt means no slot was destructured this round: stop.
246  if (failed(tryToDestructureMemorySlots(allocators, rewriter, dataLayout,
247  statistics)))
248  break;
249 
250  changed = true;
251  }
252  }
253  if (!changed)
254  markAllAnalysesPreserved();
255  }
256 };
257 
258 } // namespace
static void destructureSlot(DestructurableMemorySlot &slot, DestructurableAllocationOpInterface allocator, RewriterBase &rewriter, const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, const SROAStatistics &statistics)
Performs the destructuring of a destructurable slot given associated destructuring information.
Definition: SROA.cpp:135
static std::optional< MemorySlotDestructuringInfo > computeDestructuringInfo(DestructurableMemorySlot &slot, const DataLayout &dataLayout)
Computes information for slot destructuring.
Definition: SROA.cpp:46
The main mechanism for performing data layout queries.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:756
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Include the generated interface declarations.
LogicalResult tryToDestructureMemorySlots(ArrayRef< DestructurableAllocationOpInterface > allocators, RewriterBase &rewriter, const DataLayout &dataLayout, SROAStatistics statistics={})
Attempts to destructure the slots of destructurable allocators.
Definition: SROA.cpp:192
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ Delete
Delete the operation after promotion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.
Statistics collected while applying SROA.
Definition: SROA.h:19
llvm::Statistic * maxSubelementAmount
Maximal number of sub-elements a successfully destructured slot initially had.
Definition: SROA.h:27
llvm::Statistic * slotsWithMemoryBenefit
Total amount of memory slots in which the destructured size was smaller than the total size after eliminating unused fields.
Definition: SROA.h:24
llvm::Statistic * destructuredAmount
Total amount of memory slots destructured.
Definition: SROA.h:21