MLIR  22.0.0git
BufferViewFlowAnalysis.cpp
Go to the documentation of this file.
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "llvm/ADT/SetOperations.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 
21 //===----------------------------------------------------------------------===//
22 // BufferViewFlowAnalysis
23 //===----------------------------------------------------------------------===//
24 
25 /// Constructs a new alias analysis using the op provided.
27 
32  queue.push_back(value);
33  while (!queue.empty()) {
34  Value currentValue = queue.pop_back_val();
35  if (result.insert(currentValue).second) {
36  auto it = map.find(currentValue);
37  if (it != map.end()) {
38  for (Value aliasValue : it->second)
39  queue.push_back(aliasValue);
40  }
41  }
42  }
43  return result;
44 }
45 
46 /// Find all immediate and indirect dependent buffers this value could
47 /// potentially have. Note that the resulting set will also contain the value
48 /// provided as it is a dependent alias of itself.
51  return resolveValues(dependencies, rootValue);
52 }
53 
56  return resolveValues(reverseDependencies, rootValue);
57 }
58 
59 /// Removes the given values from all alias sets.
61  for (auto &entry : dependencies)
62  llvm::set_subtract(entry.second, aliasValues);
63 }
64 
66  dependencies[to] = dependencies[from];
67  dependencies.erase(from);
68 
69  for (auto &[_, value] : dependencies) {
70  if (value.contains(from)) {
71  value.insert(to);
72  value.erase(from);
73  }
74  }
75 }
76 
77 /// This function constructs a mapping from values to its immediate
78 /// dependencies. It iterates over all blocks, gets their predecessors,
79 /// determines the values that will be passed to the corresponding block
80 /// arguments and inserts them into the underlying map. Furthermore, it wires
81 /// successor regions and branch-like return operations from nested regions.
82 void BufferViewFlowAnalysis::build(Operation *op) {
83  // Registers all dependencies of the given values.
84  auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
85  for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
86  this->dependencies[value].insert(dep);
87  this->reverseDependencies[dep].insert(value);
88  }
89  };
90 
91  // Mark all buffer results and buffer region entry block arguments of the
92  // given op as terminals.
93  auto populateTerminalValues = [&](Operation *op) {
94  for (Value v : op->getResults())
95  if (isa<BaseMemRefType>(v.getType()))
96  this->terminals.insert(v);
97  for (Region &r : op->getRegions())
98  for (BlockArgument v : r.getArguments())
99  if (isa<BaseMemRefType>(v.getType()))
100  this->terminals.insert(v);
101  };
102 
103  op->walk([&](Operation *op) {
104  // Query BufferViewFlowOpInterface. If the op does not implement that
105  // interface, try to infer the dependencies from other interfaces that the
106  // op may implement.
107  if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
108  bufferViewFlowOp.populateDependencies(registerDependencies);
109  for (Value v : op->getResults())
110  if (isa<BaseMemRefType>(v.getType()) &&
111  bufferViewFlowOp.mayBeTerminalBuffer(v))
112  this->terminals.insert(v);
113  for (Region &r : op->getRegions())
114  for (BlockArgument v : r.getArguments())
115  if (isa<BaseMemRefType>(v.getType()) &&
116  bufferViewFlowOp.mayBeTerminalBuffer(v))
117  this->terminals.insert(v);
118  return WalkResult::advance();
119  }
120 
121  // Add additional dependencies created by view changes to the alias list.
122  if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
123  registerDependencies(viewInterface.getViewSource(),
124  viewInterface.getViewDest());
125  return WalkResult::advance();
126  }
127 
128  if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
129  // Query all branch interfaces to link block argument dependencies.
130  Block *parentBlock = branchInterface->getBlock();
131  for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
132  it != e; ++it) {
133  // Query the branch op interface to get the successor operands.
134  auto successorOperands =
135  branchInterface.getSuccessorOperands(it.getIndex());
136  // Build the actual mapping of values to their immediate dependencies.
137  registerDependencies(successorOperands.getForwardedOperands(),
138  (*it)->getArguments().drop_front(
139  successorOperands.getProducedOperandCount()));
140  }
141  return WalkResult::advance();
142  }
143 
144  if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
145  // Query the RegionBranchOpInterface to find potential successor regions.
146  // Extract all entry regions and wire all initial entry successor inputs.
147  SmallVector<RegionSuccessor, 2> entrySuccessors;
148  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
149  entrySuccessors);
150  for (RegionSuccessor &entrySuccessor : entrySuccessors) {
151  // Wire the entry region's successor arguments with the initial
152  // successor inputs.
153  registerDependencies(
154  regionInterface.getEntrySuccessorOperands(entrySuccessor),
155  entrySuccessor.getSuccessorInputs());
156  }
157 
158  // Wire flow between regions and from region exits.
159  for (Region &region : regionInterface->getRegions()) {
160  // Iterate over all successor region entries that are reachable from the
161  // current region.
162  SmallVector<RegionSuccessor, 2> successorRegions;
163  regionInterface.getSuccessorRegions(region, successorRegions);
164  for (RegionSuccessor &successorRegion : successorRegions) {
165  // Iterate over all immediate terminator operations and wire the
166  // successor inputs with the successor operands of each terminator.
167  for (Block &block : region)
168  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
169  block.getTerminator()))
170  registerDependencies(
171  terminator.getSuccessorOperands(successorRegion),
172  successorRegion.getSuccessorInputs());
173  }
174  }
175 
176  return WalkResult::advance();
177  }
178 
179  // Region terminators are handled together with RegionBranchOpInterface.
180  if (isa<RegionBranchTerminatorOpInterface>(op))
181  return WalkResult::advance();
182 
183  if (isa<CallOpInterface>(op)) {
184  // This is an intra-function analysis. We have no information about other
185  // functions. Conservatively assume that each operand may alias with each
186  // result. Also mark the results are terminals because the function could
187  // return newly allocated buffers.
188  populateTerminalValues(op);
189  for (Value operand : op->getOperands())
190  for (Value result : op->getResults())
191  registerDependencies({operand}, {result});
192  return WalkResult::advance();
193  }
194 
195  // We have no information about unknown ops.
196  populateTerminalValues(op);
197 
198  return WalkResult::advance();
199  });
200 }
201 
203  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
204  return terminals.contains(value);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // BufferOriginAnalysis
209 //===----------------------------------------------------------------------===//
210 
211 /// Return "true" if the given value is the result of a memory allocation.
212 static bool hasAllocateSideEffect(Value v) {
213  Operation *op = v.getDefiningOp();
214  if (!op)
215  return false;
216  return hasEffect<MemoryEffects::Allocate>(op, v);
217 }
218 
219 /// Return "true" if the given value is a function block argument.
220 static bool isFunctionArgument(Value v) {
221  auto bbArg = dyn_cast<BlockArgument>(v);
222  if (!bbArg)
223  return false;
224  Block *b = bbArg.getOwner();
225  auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
226  if (!funcOp)
227  return false;
228  return bbArg.getOwner() == &funcOp.getFunctionBody().front();
229 }
230 
231 /// Given a memref value, return the "base" value by skipping over all
232 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
233 static Value getViewBase(Value value) {
234  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
235  if (value != viewLikeOp.getViewDest()) {
236  break;
237  }
238  value = viewLikeOp.getViewSource();
239  }
240  return value;
241 }
242 
244 
246  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
247  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
248 
249  // Skip over all view-like ops.
250  v1 = getViewBase(v1);
251  v2 = getViewBase(v2);
252 
253  // Fast path: If both buffers are the same SSA value, we can be sure that
254  // they originate from the same allocation.
255  if (v1 == v2)
256  return true;
257 
258  // Compute the SSA values from which the buffers `v1` and `v2` originate.
259  SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
260  SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
261 
262  // Originating buffers are "terminal" if they could not be traced back any
263  // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
264  // - function block arguments
265  // - values defined by allocation ops such as "memref.alloc"
266  // - values defined by ops that are unknown to the buffer view flow analysis
267  // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
268  SmallPtrSet<Value, 16> terminal1, terminal2;
269 
270  // While gathering terminal buffers, keep track of whether all terminal
271  // buffers are newly allocated buffer or function entry arguments.
272  bool allAllocs1 = true, allAllocs2 = true;
273  bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
274 
275  // Helper function that gathers terminal buffers among `origin`.
276  auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
277  SmallPtrSet<Value, 16> &terminal,
278  bool &allAllocs,
279  bool &allAllocsOrFuncEntryArgs) {
280  for (Value v : origin) {
281  if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
282  terminal.insert(v);
283  allAllocs &= hasAllocateSideEffect(v);
284  allAllocsOrFuncEntryArgs &=
286  }
287  }
288  assert(!terminal.empty() && "expected non-empty terminal set");
289  };
290 
291  // Gather terminal buffers for `v1` and `v2`.
292  gatherTerminalBuffers(origin1, terminal1, allAllocs1,
293  allAllocsOrFuncEntryArgs1);
294  gatherTerminalBuffers(origin2, terminal2, allAllocs2,
295  allAllocsOrFuncEntryArgs2);
296 
297  // If both `v1` and `v2` have a single matching terminal buffer, they are
298  // guaranteed to originate from the same buffer allocation.
299  if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
300  *terminal1.begin() == *terminal2.begin())
301  return true;
302 
303  // At least one of the two values has multiple terminals.
304 
305  // Check if there is overlap between the terminal buffers of `v1` and `v2`.
306  bool distinctTerminalSets = true;
307  for (Value v : terminal1)
308  distinctTerminalSets &= !terminal2.contains(v);
309  // If there is overlap between the terminal buffers of `v1` and `v2`, we
310  // cannot make an accurate decision without further analysis.
311  if (!distinctTerminalSets)
312  return std::nullopt;
313 
314  // If `v1` originates from only allocs, and `v2` is guaranteed to originate
315  // from different allocations (that is guaranteed if `v2` originates from
316  // only distinct allocs or function entry arguments), we can be sure that
317  // `v1` and `v2` originate from different allocations. The same argument can
318  // be made when swapping `v1` and `v2`.
319  bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
320  bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
321  if (isolatedAlloc1 || isolatedAlloc2)
322  return false;
323 
324  // Otherwise: We do not know whether `v1` and `v2` originate from the same
325  // allocation or not.
326  // TODO: Function arguments are currently handled conservatively. We assume
327  // that they could be the same allocation.
328  // TODO: Terminals other than allocations and function arguments are
329  // currently handled conservatively. We assume that they could be the same
330  // allocation. E.g., we currently return "nullopt" for values that originate
331  // from different "memref.get_global" ops (with different symbols).
332  return std::nullopt;
333 }
static bool isFunctionArgument(Value v)
Return "true" if the given value is a function block argument.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static BufferViewFlowAnalysis::ValueSetT resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value)
static bool hasAllocateSideEffect(Value v)
Return "true" if the given value is the result of a memory allocation.
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
succ_iterator succ_end()
Definition: Block.h:269
succ_iterator succ_begin()
Definition: Block.h:268
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation.
BufferViewFlowAnalysis(Operation *op)
Constructs a new alias analysis using the op provided.
void remove(const SetVector< Value > &aliasValues)
Removes the given values from all alias sets.
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
void rename(Value from, Value to)
Replaces all occurrences of 'from' in the internal data structures with 'to'.
bool mayBeTerminalBuffer(Value value) const
Returns "true" if the given value may be a terminal.
ValueSetT resolveReverse(Value value) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
detail::InFlightRemark analysis(Location loc, RemarkOpts opts)
Report an optimization analysis remark.
Definition: Remarks.h:497
Include the generated interface declarations.