MLIR 23.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements analysis functions specific to slicing: computing
// forward and backward slices of operations and values, and matching reductions.
//
//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Block.h"
16#include "mlir/IR/Operation.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SetVector.h"
21
///
/// Implements analysis functions specific to slicing (forward/backward slices
/// of operations and values, combined slices, and reduction matching).
///
25
26using namespace mlir;
27
28static void
30 SetVector<Operation *> *forwardSlice,
31 const SliceOptions::TransitiveFilter &filter = nullptr) {
32 if (!op)
33 return;
34
35 // Evaluate whether we should keep this use.
36 // This is useful in particular to implement scoping; i.e. return the
37 // transitive forwardSlice in the current scope.
38 if (filter && !filter(op))
39 return;
40
41 for (Region &region : op->getRegions())
42 for (Block &block : region)
43 for (Operation &blockOp : block)
44 if (forwardSlice->count(&blockOp) == 0) {
45 // We don't have to check if the 'blockOp' is already visited because
46 // there cannot be a traversal path from this nested op to the parent
47 // and thus a cycle cannot be closed here. We still have to mark it
48 // as visited to stop before visiting this operation again if it is
49 // part of a cycle.
50 visited.insert(&blockOp);
51 getForwardSliceImpl(&blockOp, visited, forwardSlice, filter);
52 visited.erase(&blockOp);
53 }
54
55 for (Value result : op->getResults())
56 for (Operation *userOp : result.getUsers()) {
57 // A cycle can only occur within a basic block (not across regions or
58 // basic blocks) because the parent region must be a graph region, graph
59 // regions are restricted to always have 0 or 1 blocks, and there cannot
60 // be a def-use edge from a nested operation to an operation in an
61 // ancestor region. Therefore, we don't have to but may use the same
62 // 'visited' set across regions/blocks as long as we remove operations
63 // from the set again when the DFS traverses back from the leaf to the
64 // root.
65 if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second) {
66 getForwardSliceImpl(userOp, visited, forwardSlice, filter);
67 visited.erase(userOp);
68 }
69 }
70
71 forwardSlice->insert(op);
72}
73
77 visited.insert(op);
78 getForwardSliceImpl(op, visited, forwardSlice, options.filter);
79 if (!options.inclusive) {
80 // Don't insert the top level operation, we just queried on it and don't
81 // want it in the results.
82 forwardSlice->remove(op);
83 }
84 assert(visited.size() == 1 && "visited set should only contain op");
85
86 // Reverse to get back the actual topological order.
87 // std::reverse does not work out of the box on SetVector and I want an
88 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
89 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
90 forwardSlice->insert(v.rbegin(), v.rend());
91}
92
94 const SliceOptions &options) {
96 for (Operation *user : root.getUsers()) {
97 visited.insert(user);
98 getForwardSliceImpl(user, visited, forwardSlice, options.filter);
99 visited.erase(user);
100 }
101 assert(visited.empty() && "visited set should be empty");
102
103 // Reverse to get back the actual topological order.
104 // std::reverse does not work out of the box on SetVector and I want an
105 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
106 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
107 forwardSlice->insert(v.rbegin(), v.rend());
108}
109
110static LogicalResult getBackwardSliceImpl(Operation *op,
111 DenseSet<Operation *> &visited,
112 SetVector<Operation *> *backwardSlice,
114 if (!op)
115 return success();
116
117 // Evaluate whether we should keep this def.
118 // This is useful in particular to implement scoping; i.e. return the
119 // transitive backwardSlice in the current scope.
120 if (options.filter && !options.filter(op))
121 return success();
122
123 auto processValue = [&](Value value) {
124 if (auto *definingOp = value.getDefiningOp()) {
125 if (backwardSlice->count(definingOp) == 0 &&
126 visited.insert(definingOp).second) {
127 LogicalResult result =
128 getBackwardSliceImpl(definingOp, visited, backwardSlice, options);
129 visited.erase(definingOp);
130 return result;
131 }
132 } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
133 if (options.omitBlockArguments)
134 return success();
135
136 Block *block = blockArg.getOwner();
137 Operation *parentOp = block->getParentOp();
138 // TODO: determine whether we want to recurse backward into the other
139 // blocks of parentOp, which are not technically backward unless they flow
140 // into us. For now, just bail.
141 if (parentOp && backwardSlice->count(parentOp) == 0) {
142 if (!parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
143 parentOp->getNumRegions() == 1 &&
144 parentOp->getRegion(0).hasOneBlock()) {
145 return getBackwardSliceImpl(parentOp, visited, backwardSlice,
146 options);
147 }
148 }
149 } else {
150 return failure();
151 }
152 return success();
153 };
154
155 bool succeeded = true;
156
157 if (!options.omitUsesFromAbove &&
159 llvm::for_each(op->getRegions(), [&](Region &region) {
160 // Walk this region recursively to collect the regions that descend from
161 // this op's nested regions (inclusive).
162 SmallPtrSet<Region *, 4> descendents;
163 region.walk(
164 [&](Region *childRegion) { descendents.insert(childRegion); });
165 region.walk([&](Operation *op) {
166 for (OpOperand &operand : op->getOpOperands()) {
167 if (!descendents.contains(operand.get().getParentRegion()))
168 if (!processValue(operand.get()).succeeded()) {
169 return WalkResult::interrupt();
170 }
171 }
172 return WalkResult::advance();
173 });
174 });
175 }
176 llvm::for_each(op->getOperands(), processValue);
177
178 backwardSlice->insert(op);
179 return success(succeeded);
180}
181
183 SetVector<Operation *> *backwardSlice,
185 DenseSet<Operation *> visited;
186 visited.insert(op);
187 LogicalResult result =
188 getBackwardSliceImpl(op, visited, backwardSlice, options);
189 assert(visited.size() == 1 && "visited set should only contain op");
190
191 if (!options.inclusive) {
192 // Don't insert the top level operation, we just queried on it and don't
193 // want it in the results.
194 backwardSlice->remove(op);
195 }
196 return result;
197}
198
199LogicalResult mlir::getBackwardSlice(Value root,
200 SetVector<Operation *> *backwardSlice,
202 if (Operation *definingOp = root.getDefiningOp()) {
203 return getBackwardSlice(definingOp, backwardSlice, options);
204 }
205 Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
206 return getBackwardSlice(bbAargOwner, backwardSlice, options);
207}
208
210mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
211 const ForwardSliceOptions &forwardSliceOptions) {
213 slice.insert(op);
214
215 unsigned currentIndex = 0;
216 SetVector<Operation *> backwardSlice;
217 SetVector<Operation *> forwardSlice;
218 while (currentIndex != slice.size()) {
219 auto *currentOp = (slice)[currentIndex];
220 // Compute and insert the backwardSlice starting from currentOp.
221 backwardSlice.clear();
222 LogicalResult result =
223 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
224 assert(result.succeeded());
225 (void)result;
226 slice.insert_range(backwardSlice);
227
228 // Compute and insert the forwardSlice starting from currentOp.
229 forwardSlice.clear();
230 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
231 slice.insert_range(forwardSlice);
232 ++currentIndex;
233 }
234 return topologicalSort(slice);
235}
236
237/// Returns true if `value` (transitively) depends on iteration-carried values
238/// of the given `ancestorOp`.
239static bool dependsOnCarriedVals(Value value,
240 ArrayRef<BlockArgument> iterCarriedArgs,
241 Operation *ancestorOp) {
242 // Compute the backward slice of the value.
244 BackwardSliceOptions sliceOptions;
245 sliceOptions.filter = [&](Operation *op) {
246 return !ancestorOp->isAncestor(op);
247 };
248 LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
249 assert(result.succeeded());
250 (void)result;
251
252 // Check that none of the operands of the operations in the backward slice are
253 // loop iteration arguments, and neither is the value itself.
254 SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
255 if (iterCarriedValSet.contains(value))
256 return true;
257
258 for (Operation *op : slice)
259 for (Value operand : op->getOperands())
260 if (iterCarriedValSet.contains(operand))
261 return true;
262
263 return false;
264}
265
266/// Utility to match a generic reduction given a list of iteration-carried
267/// arguments, `iterCarriedArgs` and the position of the potential reduction
268/// argument within the list, `redPos`. If a reduction is matched, returns the
269/// reduced value and the topologically-sorted list of combiner operations
270/// involved in the reduction. Otherwise, returns a null value.
271///
272/// The matching algorithm relies on the following invariants, which are subject
273/// to change:
274/// 1. The first combiner operation must be a binary operation with the
275/// iteration-carried value and the reduced value as operands.
276/// 2. The iteration-carried value and combiner operations must be side
277/// effect-free, have single result and a single use.
278/// 3. Combiner operations must be immediately nested in the region op
279/// performing the reduction.
280/// 4. Reduction def-use chain must end in a terminator op that yields the
281/// next iteration/output values in the same order as the iteration-carried
282/// values in `iterCarriedArgs`.
283/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
284/// of the region op performing the reduction.
285///
286/// This utility is generic enough to detect reductions involving multiple
287/// combiner operations (disabled for now) across multiple dialects, including
288/// Linalg, Affine and SCF. For the sake of genericity, it does not return
289/// specific enum values for the combiner operations since its goal is also
290/// matching reductions without pre-defined semantics in core MLIR. It's up to
291/// each client to make sense out of the list of combiner operations. It's also
292/// up to each client to check for additional invariants on the expected
293/// reductions not covered by this generic matching.
295 unsigned redPos,
296 SmallVectorImpl<Operation *> &combinerOps) {
297 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
298
299 BlockArgument redCarriedVal = iterCarriedArgs[redPos];
300 if (!redCarriedVal.hasOneUse())
301 return nullptr;
302
303 // For now, the first combiner op must be a binary op.
304 Operation *combinerOp = *redCarriedVal.getUsers().begin();
305 if (combinerOp->getNumOperands() != 2)
306 return nullptr;
307 Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
308 ? combinerOp->getOperand(1)
309 : combinerOp->getOperand(0);
310
311 Operation *redRegionOp =
312 iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
313 if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
314 return nullptr;
315
316 // Traverse the def-use chain starting from the first combiner op until a
317 // terminator is found. Gather all the combiner ops along the way in
318 // topological order.
319 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
320 if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
321 !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
322 return nullptr;
323
324 combinerOps.push_back(combinerOp);
325 combinerOp = *combinerOp->getUsers().begin();
326 }
327
328 // Limit matching to single combiner op until we can properly test reductions
329 // involving multiple combiners.
330 if (combinerOps.size() != 1)
331 return nullptr;
332
333 // Check that the yielded value is in the same position as in
334 // `iterCarriedArgs`.
335 Operation *terminatorOp = combinerOp;
336 if (redPos >= terminatorOp->getNumOperands() ||
337 terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
338 return nullptr;
339
340 return reducedVal;
341}
return success()
static llvm::ManagedStatic< PassManagerOptions > options
static void processValue(Value value, LiveMap &liveMap)
static LogicalResult getBackwardSliceImpl(Operation *op, DenseSet< Operation * > &visited, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options)
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp.
static void getForwardSliceImpl(Operation *op, DenseSet< Operation * > &visited, SetVector< Operation * > *forwardSlice, const SliceOptions::TransitiveFilter &filter=nullptr)
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class represents an operand of an operation.
Definition Value.h:254
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
Value getOperand(unsigned idx)
Definition Operation.h:375
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:782
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:874
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
unsigned getNumOperands()
Definition Operation.h:371
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:288
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:296
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Include the generated interface declarations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SliceOptions ForwardSliceOptions
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
std::function< bool(Operation *)> TransitiveFilter
Type of the condition to limit the propagation of transitive use-defs.
TransitiveFilter filter