16#include "llvm/ADT/SetOperations.h"
32 queue.push_back(value);
33 while (!queue.empty()) {
34 Value currentValue = queue.pop_back_val();
35 if (
result.insert(currentValue).second) {
36 auto it = map.find(currentValue);
37 if (it != map.end()) {
38 for (
Value aliasValue : it->second)
39 queue.push_back(aliasValue);
61 for (
auto &entry : dependencies)
62 llvm::set_subtract(entry.second, aliasValues);
66 dependencies[to] = dependencies[from];
67 dependencies.erase(from);
69 for (
auto &[_, value] : dependencies) {
70 if (value.contains(from)) {
82void BufferViewFlowAnalysis::build(
Operation *op) {
85 for (
auto [value, dep] : llvm::zip_equal(values, dependencies)) {
86 this->dependencies[value].insert(dep);
87 this->reverseDependencies[dep].insert(value);
93 auto populateTerminalValues = [&](Operation *op) {
95 if (isa<BaseMemRefType>(v.getType()))
96 this->terminals.insert(v);
98 for (BlockArgument v : r.getArguments())
99 if (isa<BaseMemRefType>(v.getType()))
100 this->terminals.insert(v);
103 op->
walk([&](Operation *op) {
107 if (
auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
108 bufferViewFlowOp.populateDependencies(registerDependencies);
110 if (isa<BaseMemRefType>(v.getType()) &&
111 bufferViewFlowOp.mayBeTerminalBuffer(v))
112 this->terminals.insert(v);
114 for (BlockArgument v : r.getArguments())
115 if (isa<BaseMemRefType>(v.getType()) &&
116 bufferViewFlowOp.mayBeTerminalBuffer(v))
117 this->terminals.insert(v);
122 if (
auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
123 registerDependencies(viewInterface.getViewSource(),
124 viewInterface.getViewDest());
128 if (
auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
130 Block *parentBlock = branchInterface->getBlock();
134 auto successorOperands =
135 branchInterface.getSuccessorOperands(it.getIndex());
137 registerDependencies(successorOperands.getForwardedOperands(),
138 (*it)->getArguments().drop_front(
139 successorOperands.getProducedOperandCount()));
144 if (
auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
147 SmallVector<RegionSuccessor, 2> entrySuccessors;
150 for (RegionSuccessor &entrySuccessor : entrySuccessors) {
153 registerDependencies(
154 regionInterface.getEntrySuccessorOperands(entrySuccessor),
155 entrySuccessor.getSuccessorInputs());
159 for (Region ®ion : regionInterface->getRegions()) {
162 SmallVector<RegionSuccessor, 2> successorRegions;
163 regionInterface.getSuccessorRegions(region, successorRegions);
164 for (RegionSuccessor &successorRegion : successorRegions) {
167 for (
Block &block : region)
168 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
169 block.getTerminator()))
170 registerDependencies(
171 terminator.getSuccessorOperands(successorRegion),
172 successorRegion.getSuccessorInputs());
180 if (isa<RegionBranchTerminatorOpInterface>(op))
183 if (isa<CallOpInterface>(op)) {
188 populateTerminalValues(op);
191 registerDependencies({operand}, {
result});
196 populateTerminalValues(op);
203 assert(isa<BaseMemRefType>(value.
getType()) &&
"expected memref");
204 return terminals.contains(value);
221 auto bbArg = dyn_cast<BlockArgument>(v);
224 Block *
b = bbArg.getOwner();
225 auto funcOp = dyn_cast<FunctionOpInterface>(
b->getParentOp());
228 return bbArg.getOwner() == &funcOp.getFunctionBody().front();
234 while (
auto viewLikeOp = value.
getDefiningOp<ViewLikeOpInterface>()) {
235 if (value != viewLikeOp.getViewDest()) {
238 value = viewLikeOp.getViewSource();
246 assert(isa<BaseMemRefType>(v1.
getType()) &&
"expected buffer");
247 assert(isa<BaseMemRefType>(v2.
getType()) &&
"expected buffer");
272 bool allAllocs1 =
true, allAllocs2 =
true;
273 bool allAllocsOrFuncEntryArgs1 =
true, allAllocsOrFuncEntryArgs2 =
true;
279 bool &allAllocsOrFuncEntryArgs) {
280 for (
Value v : origin) {
281 if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
284 allAllocsOrFuncEntryArgs &=
288 assert(!terminal.empty() &&
"expected non-empty terminal set");
292 gatherTerminalBuffers(origin1, terminal1, allAllocs1,
293 allAllocsOrFuncEntryArgs1);
294 gatherTerminalBuffers(origin2, terminal2, allAllocs2,
295 allAllocsOrFuncEntryArgs2);
299 if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
300 *terminal1.begin() == *terminal2.begin())
306 bool distinctTerminalSets =
true;
307 for (
Value v : terminal1)
308 distinctTerminalSets &= !terminal2.contains(v);
311 if (!distinctTerminalSets)
319 bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
320 bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
321 if (isolatedAlloc1 || isolatedAlloc2)
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool isFunctionArgument(Value v)
Return "true" if the given value is a function block argument.
static BufferViewFlowAnalysis::ValueSetT resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value)
static bool hasAllocateSideEffect(Value v)
Return "true" if the given value is the result of a memory allocation.
template bool mlir::hasEffect< MemoryEffects::Allocate >(Operation *)
Block represents an ordered list of Operations.
succ_iterator succ_begin()
BufferOriginAnalysis(Operation *op)
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation.
SmallPtrSet< Value, 16 > ValueSetT
BufferViewFlowAnalysis(Operation *op)
Constructs a new alias analysis using the op provided.
void remove(const SetVector< Value > &aliasValues)
Removes the given values from all alias sets.
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
llvm::DenseMap< Value, ValueSetT > ValueMapT
void rename(Value from, Value to)
Replaces all occurrences of 'from' in the internal datastructures with 'to'.
bool mayBeTerminalBuffer(Value value) const
Returns "true" if the given value may be a terminal.
ValueSetT resolveReverse(Value value) const
Operation is the basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector