MLIR 23.0.0git
BufferViewFlowAnalysis.cpp
Go to the documentation of this file.
1//======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "llvm/ADT/SetOperations.h"
17
18using namespace mlir;
19using namespace mlir::bufferization;
20
21//===----------------------------------------------------------------------===//
22// BufferViewFlowAnalysis
23//===----------------------------------------------------------------------===//
24
25/// Constructs a new alias analysis using the op provided.
27
32 queue.push_back(value);
33 while (!queue.empty()) {
34 Value currentValue = queue.pop_back_val();
35 if (result.insert(currentValue).second) {
36 auto it = map.find(currentValue);
37 if (it != map.end()) {
38 for (Value aliasValue : it->second)
39 queue.push_back(aliasValue);
40 }
41 }
42 }
43 return result;
44}
45
46/// Find all immediate and indirect dependent buffers this value could
47/// potentially have. Note that the resulting set will also contain the value
48/// provided as it is a dependent alias of itself.
51 return resolveValues(dependencies, rootValue);
52}
53
56 return resolveValues(reverseDependencies, rootValue);
57}
58
59/// Removes the given values from all alias sets.
61 for (auto &entry : dependencies)
62 llvm::set_subtract(entry.second, aliasValues);
63}
64
66 dependencies[to] = dependencies[from];
67 dependencies.erase(from);
68
69 for (auto &[_, value] : dependencies) {
70 if (value.contains(from)) {
71 value.insert(to);
72 value.erase(from);
73 }
74 }
75}
76
77/// Builds the buffer view flow graph for `op` and every operation nested in
78/// it (visited via `walk`). For each value it records the values it may
79/// immediately flow into (`dependencies`), the inverse edges
80/// (`reverseDependencies`), and collects the "terminal" values whose origin
81/// cannot be traced back any further (`terminals`).
///
/// Each visited op is handled by the first matching rule, in this order:
///   1. BufferViewFlowOpInterface: the op reports its own dependencies, and
///      `mayBeTerminalBuffer` decides which of its buffer results and region
///      entry block arguments are terminals.
///   2. ViewLikeOpInterface: the view source flows into the view result.
///   3. BranchOpInterface: forwarded successor operands flow into the
///      corresponding successor block arguments.
///   4. RegionBranchOpInterface: each successor operand flows into the
///      successor inputs reported by `getSuccessorOperandInputMapping`.
///   5. RegionBranchTerminatorOpInterface: nothing is recorded here; these
///      terminators are covered together with rule 4.
///   6. CallOpInterface: every operand may flow into every result; buffer
///      results and region entry block arguments become terminals.
///   7. Any other op: buffer results and region entry block arguments become
///      terminals; no dependencies are recorded.
82void BufferViewFlowAnalysis::build(Operation *op) {
83 // Records that each value in `values` flows into the value at the same
// position in `dependencies` (forward edge) and also stores the inverse edge.
// `zip_equal` expects both ranges to have the same length. `this->` is
// required because the lambda parameter `dependencies` shadows the member.
84 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
85 for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
86 this->dependencies[value].insert(dep);
87 this->reverseDependencies[dep].insert(value);
88 }
89 };
90
91 // Mark all buffer results and buffer region entry block arguments of the
92 // given op as terminals. Non-buffer (non-BaseMemRefType) values are ignored.
93 auto populateTerminalValues = [&](Operation *op) {
94 for (Value v : op->getResults())
95 if (isa<BaseMemRefType>(v.getType()))
96 this->terminals.insert(v);
97 for (Region &r : op->getRegions())
98 for (BlockArgument v : r.getArguments())
99 if (isa<BaseMemRefType>(v.getType()))
100 this->terminals.insert(v);
101 };
102
103 op->walk([&](Operation *op) {
104 // Query BufferViewFlowOpInterface. If the op does not implement that
105 // interface, try to infer the dependencies from other interfaces that the
106 // op may implement.
107 if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
// The op reports its own edges, and terminal status is decided per value by
// the op itself instead of conservatively marking every buffer as terminal.
108 bufferViewFlowOp.populateDependencies(registerDependencies);
109 for (Value v : op->getResults())
110 if (isa<BaseMemRefType>(v.getType()) &&
111 bufferViewFlowOp.mayBeTerminalBuffer(v))
112 this->terminals.insert(v);
113 for (Region &r : op->getRegions())
114 for (BlockArgument v : r.getArguments())
115 if (isa<BaseMemRefType>(v.getType()) &&
116 bufferViewFlowOp.mayBeTerminalBuffer(v))
117 this->terminals.insert(v);
118 return WalkResult::advance();
119 }
120
121 // Add additional dependencies created by view changes to the alias list.
// The view result is not marked as a terminal: it can be traced back to its
// source through the edge registered here.
122 if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
123 registerDependencies(viewInterface.getViewSource(),
124 viewInterface.getViewDest());
125 return WalkResult::advance();
126 }
127
128 if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
129 // Query all branch interfaces to link block argument dependencies.
// This iterates the successors of the enclosing block, i.e. of its
// terminator, which for a branch op is presumably the op itself.
130 Block *parentBlock = branchInterface->getBlock();
131 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
132 it != e; ++it) {
133 // Query the branch op interface to get the successor operands.
134 auto successorOperands =
135 branchInterface.getSuccessorOperands(it.getIndex());
136 // Build the actual mapping of values to their immediate dependencies.
// The leading successor block arguments that are produced by the branch
// op itself (rather than forwarded from one of its operands) have no
// matching operand, so they are skipped via `drop_front`.
137 registerDependencies(successorOperands.getForwardedOperands(),
138 (*it)->getArguments().drop_front(
139 successorOperands.getProducedOperandCount()));
140 }
141 return WalkResult::advance();
142 }
143
144 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
145 // Wire the successor operands with the successor inputs.
147 regionInterface.getSuccessorOperandInputMapping(mapping);
148 for (const auto &[operand, inputs] : mapping)
149 for (Value input : inputs)
150 registerDependencies({operand->get()}, {input});
151 return WalkResult::advance();
152 }
153
154 // Region terminators are handled together with RegionBranchOpInterface.
155 if (isa<RegionBranchTerminatorOpInterface>(op))
156 return WalkResult::advance();
157
158 if (isa<CallOpInterface>(op)) {
159 // This is an intra-function analysis. We have no information about other
160 // functions. Conservatively assume that each operand may alias with each
161 // result. Also mark the results as terminals because the function could
162 // return newly allocated buffers.
// Note: operands and results are not filtered by type, so edges are also
// registered for non-buffer values.
163 populateTerminalValues(op);
164 for (Value operand : op->getOperands())
165 for (Value result : op->getResults())
166 registerDependencies({operand}, {result});
167 return WalkResult::advance();
168 }
169
170 // We have no information about unknown ops.
// Conservatively treat their buffer results and region entry block
// arguments as terminals; no dependencies are recorded for them.
171 populateTerminalValues(op);
172
173 return WalkResult::advance();
174 });
175}
176
178 assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
179 return terminals.contains(value);
180}
181
182//===----------------------------------------------------------------------===//
183// BufferOriginAnalysis
184//===----------------------------------------------------------------------===//
185
186/// Return "true" if the given value is the result of a memory allocation.
188 Operation *op = v.getDefiningOp();
189 if (!op)
190 return false;
192}
193
194/// Return "true" if the given value is a function block argument.
195static bool isFunctionArgument(Value v) {
196 auto bbArg = dyn_cast<BlockArgument>(v);
197 if (!bbArg)
198 return false;
199 Block *b = bbArg.getOwner();
200 auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
201 if (!funcOp)
202 return false;
203 return bbArg.getOwner() == &funcOp.getFunctionBody().front();
204}
205
206/// Given a memref value, return the "base" value by skipping over all
207/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
208static Value getViewBase(Value value) {
209 while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
210 if (value != viewLikeOp.getViewDest()) {
211 break;
212 }
213 value = viewLikeOp.getViewSource();
214 }
215 return value;
216}
217
219
221 assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
222 assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
223
224 // Skip over all view-like ops.
225 v1 = getViewBase(v1);
226 v2 = getViewBase(v2);
227
228 // Fast path: If both buffers are the same SSA value, we can be sure that
229 // they originate from the same allocation.
230 if (v1 == v2)
231 return true;
232
233 // Compute the SSA values from which the buffers `v1` and `v2` originate.
234 SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
235 SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
236
237 // Originating buffers are "terminal" if they could not be traced back any
238 // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
239 // - function block arguments
240 // - values defined by allocation ops such as "memref.alloc"
241 // - values defined by ops that are unknown to the buffer view flow analysis
242 // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
243 SmallPtrSet<Value, 16> terminal1, terminal2;
244
245 // While gathering terminal buffers, keep track of whether all terminal
246 // buffers are newly allocated buffer or function entry arguments.
247 bool allAllocs1 = true, allAllocs2 = true;
248 bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
249
250 // Helper function that gathers terminal buffers among `origin`.
251 auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
252 SmallPtrSet<Value, 16> &terminal,
253 bool &allAllocs,
254 bool &allAllocsOrFuncEntryArgs) {
255 for (Value v : origin) {
256 if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
257 terminal.insert(v);
258 allAllocs &= hasAllocateSideEffect(v);
259 allAllocsOrFuncEntryArgs &=
261 }
262 }
263 assert(!terminal.empty() && "expected non-empty terminal set");
264 };
265
266 // Gather terminal buffers for `v1` and `v2`.
267 gatherTerminalBuffers(origin1, terminal1, allAllocs1,
268 allAllocsOrFuncEntryArgs1);
269 gatherTerminalBuffers(origin2, terminal2, allAllocs2,
270 allAllocsOrFuncEntryArgs2);
271
272 // If both `v1` and `v2` have a single matching terminal buffer, they are
273 // guaranteed to originate from the same buffer allocation.
274 if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
275 *terminal1.begin() == *terminal2.begin())
276 return true;
277
278 // At least one of the two values has multiple terminals.
279
280 // Check if there is overlap between the terminal buffers of `v1` and `v2`.
281 bool distinctTerminalSets = true;
282 for (Value v : terminal1)
283 distinctTerminalSets &= !terminal2.contains(v);
284 // If there is overlap between the terminal buffers of `v1` and `v2`, we
285 // cannot make an accurate decision without further analysis.
286 if (!distinctTerminalSets)
287 return std::nullopt;
288
289 // If `v1` originates from only allocs, and `v2` is guaranteed to originate
290 // from different allocations (that is guaranteed if `v2` originates from
291 // only distinct allocs or function entry arguments), we can be sure that
292 // `v1` and `v2` originate from different allocations. The same argument can
293 // be made when swapping `v1` and `v2`.
294 bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
295 bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
296 if (isolatedAlloc1 || isolatedAlloc2)
297 return false;
298
299 // Otherwise: We do not know whether `v1` and `v2` originate from the same
300 // allocation or not.
301 // TODO: Function arguments are currently handled conservatively. We assume
302 // that they could be the same allocation.
303 // TODO: Terminals other than allocations and function arguments are
304 // currently handled conservatively. We assume that they could be the same
305 // allocation. E.g., we currently return "nullopt" for values that originate
306 // from different "memref.get_global" ops (with different symbols).
307 return std::nullopt;
308}
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static bool isFunctionArgument(Value v)
Return "true" if the given value is a function block argument.
static BufferViewFlowAnalysis::ValueSetT resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value)
static bool hasAllocateSideEffect(Value v)
Return "true" if the given value is the result of a memory allocation.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
template bool mlir::hasEffect< MemoryEffects::Allocate >(Operation *)
Block represents an ordered list of Operations.
Definition Block.h:33
succ_iterator succ_end()
Definition Block.h:279
succ_iterator succ_begin()
Definition Block.h:278
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation.
SmallPtrSet< Value, 16 > ValueSetT
BufferViewFlowAnalysis(Operation *op)
Constructs a new alias analysis using the op provided.
void remove(const SetVector< Value > &aliasValues)
Removes the given values from all alias sets.
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
llvm::DenseMap< Value, ValueSetT > ValueMapT
void rename(Value from, Value to)
Replaces all occurrences of 'from' in the internal data structures with 'to'.
bool mayBeTerminalBuffer(Value value) const
Returns "true" if the given value may be a terminal.
ValueSetT resolveReverse(Value value) const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126