MLIR  19.0.0git
BufferViewFlowAnalysis.cpp
Go to the documentation of this file.
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "llvm/ADT/SetOperations.h"
17 #include "llvm/ADT/SetVector.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 
22 //===----------------------------------------------------------------------===//
23 // BufferViewFlowAnalysis
24 //===----------------------------------------------------------------------===//
25 
26 /// Constructs a new alias analysis using the op provided.
28 
33  queue.push_back(value);
34  while (!queue.empty()) {
35  Value currentValue = queue.pop_back_val();
36  if (result.insert(currentValue).second) {
37  auto it = map.find(currentValue);
38  if (it != map.end()) {
39  for (Value aliasValue : it->second)
40  queue.push_back(aliasValue);
41  }
42  }
43  }
44  return result;
45 }
46 
47 /// Find all immediate and indirect dependent buffers this value could
48 /// potentially have. Note that the resulting set will also contain the value
49 /// provided as it is a dependent alias of itself.
52  return resolveValues(dependencies, rootValue);
53 }
54 
57  return resolveValues(reverseDependencies, rootValue);
58 }
59 
60 /// Removes the given values from all alias sets.
62  for (auto &entry : dependencies)
63  llvm::set_subtract(entry.second, aliasValues);
64 }
65 
67  dependencies[to] = dependencies[from];
68  dependencies.erase(from);
69 
70  for (auto &[_, value] : dependencies) {
71  if (value.contains(from)) {
72  value.insert(to);
73  value.erase(from);
74  }
75  }
76 }
77 
78 /// This function constructs a mapping from values to its immediate
79 /// dependencies. It iterates over all blocks, gets their predecessors,
80 /// determines the values that will be passed to the corresponding block
81 /// arguments and inserts them into the underlying map. Furthermore, it wires
82 /// successor regions and branch-like return operations from nested regions.
83 void BufferViewFlowAnalysis::build(Operation *op) {
84  // Registers all dependencies of the given values.
85  auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
86  for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
87  this->dependencies[value].insert(dep);
88  this->reverseDependencies[dep].insert(value);
89  }
90  };
91 
92  // Mark all buffer results and buffer region entry block arguments of the
93  // given op as terminals.
94  auto populateTerminalValues = [&](Operation *op) {
95  for (Value v : op->getResults())
96  if (isa<BaseMemRefType>(v.getType()))
97  this->terminals.insert(v);
98  for (Region &r : op->getRegions())
99  for (BlockArgument v : r.getArguments())
100  if (isa<BaseMemRefType>(v.getType()))
101  this->terminals.insert(v);
102  };
103 
104  op->walk([&](Operation *op) {
105  // Query BufferViewFlowOpInterface. If the op does not implement that
106  // interface, try to infer the dependencies from other interfaces that the
107  // op may implement.
108  if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
109  bufferViewFlowOp.populateDependencies(registerDependencies);
110  for (Value v : op->getResults())
111  if (isa<BaseMemRefType>(v.getType()) &&
112  bufferViewFlowOp.mayBeTerminalBuffer(v))
113  this->terminals.insert(v);
114  for (Region &r : op->getRegions())
115  for (BlockArgument v : r.getArguments())
116  if (isa<BaseMemRefType>(v.getType()) &&
117  bufferViewFlowOp.mayBeTerminalBuffer(v))
118  this->terminals.insert(v);
119  return WalkResult::advance();
120  }
121 
122  // Add additional dependencies created by view changes to the alias list.
123  if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
124  registerDependencies(viewInterface.getViewSource(),
125  viewInterface->getResult(0));
126  return WalkResult::advance();
127  }
128 
129  if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
130  // Query all branch interfaces to link block argument dependencies.
131  Block *parentBlock = branchInterface->getBlock();
132  for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
133  it != e; ++it) {
134  // Query the branch op interface to get the successor operands.
135  auto successorOperands =
136  branchInterface.getSuccessorOperands(it.getIndex());
137  // Build the actual mapping of values to their immediate dependencies.
138  registerDependencies(successorOperands.getForwardedOperands(),
139  (*it)->getArguments().drop_front(
140  successorOperands.getProducedOperandCount()));
141  }
142  return WalkResult::advance();
143  }
144 
145  if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
146  // Query the RegionBranchOpInterface to find potential successor regions.
147  // Extract all entry regions and wire all initial entry successor inputs.
148  SmallVector<RegionSuccessor, 2> entrySuccessors;
149  regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
150  entrySuccessors);
151  for (RegionSuccessor &entrySuccessor : entrySuccessors) {
152  // Wire the entry region's successor arguments with the initial
153  // successor inputs.
154  registerDependencies(
155  regionInterface.getEntrySuccessorOperands(entrySuccessor),
156  entrySuccessor.getSuccessorInputs());
157  }
158 
159  // Wire flow between regions and from region exits.
160  for (Region &region : regionInterface->getRegions()) {
161  // Iterate over all successor region entries that are reachable from the
162  // current region.
163  SmallVector<RegionSuccessor, 2> successorRegions;
164  regionInterface.getSuccessorRegions(region, successorRegions);
165  for (RegionSuccessor &successorRegion : successorRegions) {
166  // Iterate over all immediate terminator operations and wire the
167  // successor inputs with the successor operands of each terminator.
168  for (Block &block : region)
169  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
170  block.getTerminator()))
171  registerDependencies(
172  terminator.getSuccessorOperands(successorRegion),
173  successorRegion.getSuccessorInputs());
174  }
175  }
176 
177  return WalkResult::advance();
178  }
179 
180  // Region terminators are handled together with RegionBranchOpInterface.
181  if (isa<RegionBranchTerminatorOpInterface>(op))
182  return WalkResult::advance();
183 
184  if (isa<CallOpInterface>(op)) {
185  // This is an intra-function analysis. We have no information about other
186  // functions. Conservatively assume that each operand may alias with each
187  // result. Also mark the results are terminals because the function could
188  // return newly allocated buffers.
189  populateTerminalValues(op);
190  for (Value operand : op->getOperands())
191  for (Value result : op->getResults())
192  registerDependencies({operand}, {result});
193  return WalkResult::advance();
194  }
195 
196  // We have no information about unknown ops.
197  populateTerminalValues(op);
198 
199  return WalkResult::advance();
200  });
201 }
202 
204  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
205  return terminals.contains(value);
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // BufferOriginAnalysis
210 //===----------------------------------------------------------------------===//
211 
212 /// Return "true" if the given value is the result of a memory allocation.
213 static bool hasAllocateSideEffect(Value v) {
214  Operation *op = v.getDefiningOp();
215  if (!op)
216  return false;
217  return hasEffect<MemoryEffects::Allocate>(op, v);
218 }
219 
220 /// Return "true" if the given value is a function block argument.
221 static bool isFunctionArgument(Value v) {
222  auto bbArg = dyn_cast<BlockArgument>(v);
223  if (!bbArg)
224  return false;
225  Block *b = bbArg.getOwner();
226  auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
227  if (!funcOp)
228  return false;
229  return bbArg.getOwner() == &funcOp.getFunctionBody().front();
230 }
231 
232 /// Given a memref value, return the "base" value by skipping over all
233 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
234 static Value getViewBase(Value value) {
235  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
236  value = viewLikeOp.getViewSource();
237  return value;
238 }
239 
241 
243  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
244  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
245 
246  // Skip over all view-like ops.
247  v1 = getViewBase(v1);
248  v2 = getViewBase(v2);
249 
250  // Fast path: If both buffers are the same SSA value, we can be sure that
251  // they originate from the same allocation.
252  if (v1 == v2)
253  return true;
254 
255  // Compute the SSA values from which the buffers `v1` and `v2` originate.
256  SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
257  SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
258 
259  // Originating buffers are "terminal" if they could not be traced back any
260  // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
261  // - function block arguments
262  // - values defined by allocation ops such as "memref.alloc"
263  // - values defined by ops that are unknown to the buffer view flow analysis
264  // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
265  SmallPtrSet<Value, 16> terminal1, terminal2;
266 
267  // While gathering terminal buffers, keep track of whether all terminal
268  // buffers are newly allocated buffer or function entry arguments.
269  bool allAllocs1 = true, allAllocs2 = true;
270  bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
271 
272  // Helper function that gathers terminal buffers among `origin`.
273  auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
274  SmallPtrSet<Value, 16> &terminal,
275  bool &allAllocs,
276  bool &allAllocsOrFuncEntryArgs) {
277  for (Value v : origin) {
278  if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
279  terminal.insert(v);
280  allAllocs &= hasAllocateSideEffect(v);
281  allAllocsOrFuncEntryArgs &=
283  }
284  }
285  assert(!terminal.empty() && "expected non-empty terminal set");
286  };
287 
288  // Gather terminal buffers for `v1` and `v2`.
289  gatherTerminalBuffers(origin1, terminal1, allAllocs1,
290  allAllocsOrFuncEntryArgs1);
291  gatherTerminalBuffers(origin2, terminal2, allAllocs2,
292  allAllocsOrFuncEntryArgs2);
293 
294  // If both `v1` and `v2` have a single matching terminal buffer, they are
295  // guaranteed to originate from the same buffer allocation.
296  if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
297  *terminal1.begin() == *terminal2.begin())
298  return true;
299 
300  // At least one of the two values has multiple terminals.
301 
302  // Check if there is overlap between the terminal buffers of `v1` and `v2`.
303  bool distinctTerminalSets = true;
304  for (Value v : terminal1)
305  distinctTerminalSets &= !terminal2.contains(v);
306  // If there is overlap between the terminal buffers of `v1` and `v2`, we
307  // cannot make an accurate decision without further analysis.
308  if (!distinctTerminalSets)
309  return std::nullopt;
310 
311  // If `v1` originates from only allocs, and `v2` is guaranteed to originate
312  // from different allocations (that is guaranteed if `v2` originates from
313  // only distinct allocs or function entry arguments), we can be sure that
314  // `v1` and `v2` originate from different allocations. The same argument can
315  // be made when swapping `v1` and `v2`.
316  bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
317  bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
318  if (isolatedAlloc1 || isolatedAlloc2)
319  return false;
320 
321  // Otherwise: We do not know whether `v1` and `v2` originate from the same
322  // allocation or not.
323  // TODO: Function arguments are currently handled conservatively. We assume
324  // that they could be the same allocation.
325  // TODO: Terminals other than allocations and function arguments are
326  // currently handled conservatively. We assume that they could be the same
327  // allocation. E.g., we currently return "nullopt" for values that originate
328  // from different "memref.get_global" ops (with different symbols).
329  return std::nullopt;
330 }
static bool isFunctionArgument(Value v)
Return "true" if the given value is a function block argument.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static BufferViewFlowAnalysis::ValueSetT resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value)
static bool hasAllocateSideEffect(Value v)
Return "true" if the given value is the result of a memory allocation.
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
succ_iterator succ_end()
Definition: Block.h:264
succ_iterator succ_begin()
Definition: Block.h:263
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation.
BufferViewFlowAnalysis(Operation *op)
Constructs a new alias analysis using the op provided.
void remove(const SetVector< Value > &aliasValues)
Removes the given values from all alias sets.
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
void rename(Value from, Value to)
Replaces all occurrences of 'from' in the internal datastructures with 'to'.
bool mayBeTerminalBuffer(Value value) const
Returns "true" if the given value may be a terminal.
ValueSetT resolveReverse(Value value) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:52
Include the generated interface declarations.