MLIR 22.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
1//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements forward and backward slice analyses for operations.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Block.h"
16#include "mlir/IR/Operation.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SetVector.h"
21
22///
23/// Implements forward and backward slice analyses for operations.
24///
25
26using namespace mlir;
27
/// Recursive worker for getForwardSlice. Depth-first, it visits every
/// operation nested in `op`'s regions and every user of `op`'s results. Each
/// reached operation is appended to `forwardSlice` only after everything
/// reachable from it has been handled (post-order), and `op` itself is
/// appended last. Callers reverse the result to obtain a topological order.
///
/// Operations for which `filter` returns false are neither added nor
/// traversed through. `visited` holds operations on the current DFS path so
/// that use cycles (possible in graph regions; see the comment in the user
/// loop) do not recurse forever.
///
/// NOTE(review): the line naming this function and its `op`/`visited`
/// parameters is missing from this listing (numbering jumps 28 -> 30). The
/// declaration is `getForwardSliceImpl(Operation *op, DenseSet<Operation *>
/// &visited, SetVector<Operation *> *forwardSlice,
/// const SliceOptions::TransitiveFilter &filter=nullptr)`.
28static void
30 SetVector<Operation *> *forwardSlice,
31 const SliceOptions::TransitiveFilter &filter = nullptr) {
32 if (!op)
33 return;
34
35 // Evaluate whether we should keep this use.
36 // This is useful in particular to implement scoping; i.e. return the
37 // transitive forwardSlice in the current scope.
38 if (filter && !filter(op))
39 return;
40
// First, descend into the operations directly inside op's regions. The
// recursive call handles deeper nesting. Ops already in the slice are skipped.
41 for (Region &region : op->getRegions())
42 for (Block &block : region)
43 for (Operation &blockOp : block)
44 if (forwardSlice->count(&blockOp) == 0) {
45 // We don't have to check if the 'blockOp' is already visited because
46 // there cannot be a traversal path from this nested op to the parent
47 // and thus a cycle cannot be closed here. We still have to mark it
48 // as visited to stop before visiting this operation again if it is
49 // part of a cycle.
50 visited.insert(&blockOp);
51 getForwardSliceImpl(&blockOp, visited, forwardSlice, filter);
52 visited.erase(&blockOp);
53 }
54
// Then follow def-use edges from each result of op to its users.
55 for (Value result : op->getResults())
56 for (Operation *userOp : result.getUsers()) {
57 // A cycle can only occur within a basic block (not across regions or
58 // basic blocks) because the parent region must be a graph region, graph
59 // regions are restricted to always have 0 or 1 blocks, and there cannot
60 // be a def-use edge from a nested operation to an operation in an
61 // ancestor region. Therefore, we don't have to but may use the same
62 // 'visited' set across regions/blocks as long as we remove operations
63 // from the set again when the DFS traverses back from the leaf to the
64 // root.
65 if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second)
66 getForwardSliceImpl(userOp, visited, forwardSlice, filter);
67
// NOTE(review): this erase also runs when the `visited.insert` above
// failed, i.e. when `userOp` was already on the current DFS path. That
// removes an on-path op from `visited`. Confirm this is intended; the usual
// pattern erases only after a successful insert.
68 visited.erase(userOp);
69 }
70
// Post-order: op is recorded only after all of its nested ops and users.
71 forwardSlice->insert(op);
72}
73
// NOTE(review): the head of mlir::getForwardSlice(Operation *op,
// SetVector<Operation *> *forwardSlice, const ForwardSliceOptions &options)
// and its `visited` declaration (original lines 74-76) are missing from this
// listing. The body below starts mid-function.
// Mark the root `op` as being on the DFS path before recursing from it.
77 visited.insert(op);
78 getForwardSliceImpl(op, visited, forwardSlice, options.filter);
79 if (!options.inclusive) {
80 // Don't insert the top level operation, we just queried on it and don't
81 // want it in the results.
82 forwardSlice->remove(op);
83 }
84
85 // Reverse to get back the actual topological order.
86 // std::reverse does not work out of the box on SetVector and I want an
87 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
// takeVector() moves the elements out and leaves `forwardSlice` empty.
// Re-inserting them back-to-front reverses getForwardSliceImpl's post-order.
88 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
89 forwardSlice->insert(v.rbegin(), v.rend());
90}
91
93 const SliceOptions &options) {
// Value overload: computes the forward slice of every user of `root`. Each
// user that passes `options.filter` is included in the result. Each user is
// marked in `visited` only while its own traversal runs. `options.inclusive`
// is not consulted here, because there is no root operation to drop.
// NOTE(review): the signature head (original line 92) and the declaration of
// `visited` (original line 94) are missing from this listing. Presumably
// `visited` is a local DenseSet<Operation *> shared across all users; confirm.
95 for (Operation *user : root.getUsers()) {
96 visited.insert(user);
97 getForwardSliceImpl(user, visited, forwardSlice, options.filter);
98 visited.erase(user);
99 }
100
101 // Reverse to get back the actual topological order.
102 // std::reverse does not work out of the box on SetVector and I want an
103 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
104 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
105 forwardSlice->insert(v.rbegin(), v.rend());
106}
107
/// Recursive worker for getBackwardSlice. Performs a post-order DFS over the
/// definitions feeding `op`. These are its operands and, unless
/// `options.omitUsesFromAbove` is set (plus a further condition on a line
/// missing from this listing), values used inside `op`'s regions but defined
/// outside them. `op` is appended to `backwardSlice` after its defs, so no
/// reversal is needed afterwards. Operations rejected by `options.filter` are
/// neither added nor traversed.
///
/// `processValue` yields failure for a value that is neither an op result nor
/// a block argument, and forwards failures from recursive calls. See the NOTE
/// near `succeeded` below: those failures are currently not propagated out of
/// this function.
///
/// NOTE(review): two lines are missing from this listing. One is the last
/// parameter line (original 111, `const BackwardSliceOptions &options) {`).
/// The other is the rest of the uses-from-above condition (original 155).
108static LogicalResult getBackwardSliceImpl(Operation *op,
109 DenseSet<Operation *> &visited,
110 SetVector<Operation *> *backwardSlice,
112 if (!op)
113 return success();
114
115 // Evaluate whether we should keep this def.
116 // This is useful in particular to implement scoping; i.e. return the
117 // transitive backwardSlice in the current scope.
118 if (options.filter && !options.filter(op))
119 return success();
120
121 auto processValue = [&](Value value) {
// Recurse into the producer of `value`. For an op result this is its
// defining op. For a block argument it is the op owning the block, but only
// when that op is not isolated from above and has exactly one region with a
// single block.
122 if (auto *definingOp = value.getDefiningOp()) {
123 if (backwardSlice->count(definingOp) == 0 &&
124 visited.insert(definingOp).second)
125 return getBackwardSliceImpl(definingOp, visited, backwardSlice,
126 options);
127
// Reached only when definingOp is already in the slice or already in
// `visited`. On a successful insert we return above, so definingOp stays
// in `visited` after its traversal.
// NOTE(review): this differs from getForwardSliceImpl, which erases after
// recursing, and it may erase an op that is still on the DFS path. Confirm
// the asymmetry is intended.
128 visited.erase(definingOp);
129 } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
130 if (options.omitBlockArguments)
131 return success();
132
133 Block *block = blockArg.getOwner();
134 Operation *parentOp = block->getParentOp();
135 // TODO: determine whether we want to recurse backward into the other
136 // blocks of parentOp, which are not technically backward unless they flow
137 // into us. For now, just bail.
// NOTE(review): unlike definingOp above, parentOp is not inserted into
// `visited` before recursing.
138 if (parentOp && backwardSlice->count(parentOp) == 0) {
139 if (!parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
140 parentOp->getNumRegions() == 1 &&
141 parentOp->getRegion(0).hasOneBlock()) {
142 return getBackwardSliceImpl(parentOp, visited, backwardSlice,
143 options);
144 }
145 }
146 } else {
147 return failure();
148 }
149 return success();
150 };
151
// NOTE(review): in the visible code `succeeded` is never set to false. The
// walk below only interrupts on failure, and that interrupt stops only the
// walk of the current region. The results of processValue for op's own
// operands (the llvm::for_each further down) are discarded. As a result,
// this function appears to always return success() once past the filter.
// Consider recording failures here.
152 bool succeeded = true;
153
// Values defined outside op's regions but used inside them also feed op.
154 if (!options.omitUsesFromAbove &&
156 llvm::for_each(op->getRegions(), [&](Region &region) {
157 // Walk this region recursively to collect the regions that descend from
158 // this op's nested regions (inclusive).
159 SmallPtrSet<Region *, 4> descendents;
160 region.walk(
161 [&](Region *childRegion) { descendents.insert(childRegion); });
162 region.walk([&](Operation *op) {
163 for (OpOperand &operand : op->getOpOperands()) {
164 if (!descendents.contains(operand.get().getParentRegion()))
165 if (!processValue(operand.get()).succeeded()) {
166 return WalkResult::interrupt();
167 }
168 }
169 return WalkResult::advance();
170 });
171 });
172 }
// Direct operands of op. The LogicalResult returned for each is ignored.
173 llvm::for_each(op->getOperands(), processValue);
174
// Post-order: op goes after all of its defs.
175 backwardSlice->insert(op);
176 return success(succeeded);
177}
178
180 SetVector<Operation *> *backwardSlice,
182 DenseSet<Operation *> visited;
// NOTE(review): the first signature line (original 179) and the options
// parameter line (original 181) are missing from this listing.
// Mark `op` as on the DFS path, then collect its defs. No reversal is needed,
// because getBackwardSliceImpl appends each op after its defs. The slice is
// populated and `op` is optionally removed regardless of `result`.
183 visited.insert(op);
184 LogicalResult result =
185 getBackwardSliceImpl(op, visited, backwardSlice, options);
186
187 if (!options.inclusive) {
188 // Don't insert the top level operation, we just queried on it and don't
189 // want it in the results.
190 backwardSlice->remove(op);
191 }
192 return result;
193}
194
/// Value overload of getBackwardSlice.
/// - If `root` is an op result, this is the backward slice of its defining op.
/// - Otherwise `root` is treated as a block argument (cast<> asserts this when
///   assertions are enabled), and the slice is computed from the op that owns
///   the argument's block. That owner is excluded from the result unless
///   `options.inclusive` is set.
/// - If the block has no parent op, nullptr is passed down, and
///   getBackwardSliceImpl returns success() without adding anything.
/// NOTE(review): the options parameter line (original 197) is missing from
/// this listing.
195LogicalResult mlir::getBackwardSlice(Value root,
196 SetVector<Operation *> *backwardSlice,
198 if (Operation *definingOp = root.getDefiningOp()) {
199 return getBackwardSlice(definingOp, backwardSlice, options);
200 }
201 Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
202 return getBackwardSlice(bbAargOwner, backwardSlice, options);
203}
204
206mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
207 const ForwardSliceOptions &forwardSliceOptions) {
209 slice.insert(op);
210
// NOTE(review): the return-type line (original 205) and the declaration of
// `slice` (original 208) are missing from this listing.
// Worklist fixed point: `slice` is both the result and the worklist. Every op
// added by a backward or forward slice is expanded in a later iteration. The
// loop ends once an iteration adds no new ops, and the result is then
// sorted topologically.
211 unsigned currentIndex = 0;
// Scratch sets, cleared and reused on every iteration.
212 SetVector<Operation *> backwardSlice;
213 SetVector<Operation *> forwardSlice;
214 while (currentIndex != slice.size()) {
215 auto *currentOp = (slice)[currentIndex];
216 // Compute and insert the backwardSlice starting from currentOp.
217 backwardSlice.clear();
218 LogicalResult result =
219 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
// Failure is checked only by this assert, which compiles out under NDEBUG.
220 assert(result.succeeded());
221 (void)result;
222 slice.insert_range(backwardSlice);
223
224 // Compute and insert the forwardSlice starting from currentOp.
225 forwardSlice.clear();
226 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
227 slice.insert_range(forwardSlice);
228 ++currentIndex;
229 }
230 return topologicalSort(slice);
231}
232
233/// Returns true if `value` (transitively) depends on iteration-carried values
234/// of the given `ancestorOp`.
///
/// Concretely, it returns true if `value` is itself one of `iterCarriedArgs`,
/// or if any operation in the backward slice of `value` (computed with the
/// filter below) has one of them as an operand.
///
/// NOTE(review): the filter keeps only operations that are NOT nested in
/// `ancestorOp`, because isAncestor(op) is true for ops inside `ancestorOp`.
/// Ops outside `ancestorOp` should not be able to use its block arguments.
/// As written, the operand scan therefore looks unable to ever match, and the
/// check appears to reduce to "`value` is an iteration-carried arg". Confirm
/// whether the filter was meant to be `ancestorOp->isAncestor(op)`.
235static bool dependsOnCarriedVals(Value value,
236 ArrayRef<BlockArgument> iterCarriedArgs,
237 Operation *ancestorOp) {
238 // Compute the backward slice of the value.
// NOTE(review): the declaration of `slice` (original line 239) is missing
// from this listing.
240 BackwardSliceOptions sliceOptions;
241 sliceOptions.filter = [&](Operation *op) {
242 return !ancestorOp->isAncestor(op);
243 };
244 LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
245 assert(result.succeeded());
246 (void)result;
247
248 // Check that none of the operands of the operations in the backward slice are
249 // loop iteration arguments, and neither is the value itself.
250 SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
251 if (iterCarriedValSet.contains(value))
252 return true;
253
254 for (Operation *op : slice)
255 for (Value operand : op->getOperands())
256 if (iterCarriedValSet.contains(operand))
257 return true;
258
259 return false;
260}
261
262/// Utility to match a generic reduction given a list of iteration-carried
263/// arguments, `iterCarriedArgs` and the position of the potential reduction
264/// argument within the list, `redPos`. If a reduction is matched, returns the
265/// reduced value and the topologically-sorted list of combiner operations
266/// involved in the reduction. Otherwise, returns a null value.
267///
268/// The matching algorithm relies on the following invariants, which are subject
269/// to change:
270/// 1. The first combiner operation must be a binary operation with the
271/// iteration-carried value and the reduced value as operands.
272/// 2. The iteration-carried value and combiner operations must be side
273/// effect-free, have single result and a single use.
274/// 3. Combiner operations must be immediately nested in the region op
275/// performing the reduction.
276/// 4. Reduction def-use chain must end in a terminator op that yields the
277/// next iteration/output values in the same order as the iteration-carried
278/// values in `iterCarriedArgs`.
279/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
280/// of the region op performing the reduction.
281///
282/// This utility is generic enough to detect reductions involving multiple
283/// combiner operations (disabled for now) across multiple dialects, including
284/// Linalg, Affine and SCF. For the sake of genericity, it does not return
285/// specific enum values for the combiner operations since its goal is also
286/// matching reductions without pre-defined semantics in core MLIR. It's up to
287/// each client to make sense out of the list of combiner operations. It's also
288/// up to each client to check for additional invariants on the expected
289/// reductions not covered by this generic matching.
///
/// `combinerOps` is appended to, not cleared. It should be empty on entry,
/// since the final size check expects exactly one element. On a failed match
/// it may still contain operations pushed before the mismatch was detected.
///
/// NOTE(review): the first signature line (original 290) is missing from
/// this listing. The declaration is `Value matchReduction(
/// ArrayRef<BlockArgument> iterCarriedArgs, unsigned redPos,
/// SmallVectorImpl<Operation *> &combinerOps)`.
291 unsigned redPos,
292 SmallVectorImpl<Operation *> &combinerOps) {
293 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
294
295 BlockArgument redCarriedVal = iterCarriedArgs[redPos];
296 if (!redCarriedVal.hasOneUse())
297 return nullptr;
298
299 // For now, the first combiner op must be a binary op.
// Dereferencing begin() is safe: hasOneUse() was checked above.
300 Operation *combinerOp = *redCarriedVal.getUsers().begin();
301 if (combinerOp->getNumOperands() != 2)
302 return nullptr;
303 Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
304 ? combinerOp->getOperand(1)
305 : combinerOp->getOperand(0);
306
// The op owning the region whose block holds the iteration-carried args.
307 Operation *redRegionOp =
308 iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
309 if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
310 return nullptr;
311
312 // Traverse the def-use chain starting from the first combiner op until a
313 // terminator is found. Gather all the combiner ops along the way in
314 // topological order.
315 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
316 if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
317 !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
318 return nullptr;
319
320 combinerOps.push_back(combinerOp);
321 combinerOp = *combinerOp->getUsers().begin();
322 }
323
324 // Limit matching to single combiner op until we can properly test reductions
325 // involving multiple combiners.
326 if (combinerOps.size() != 1)
327 return nullptr;
328
329 // Check that the yielded value is in the same position as in
330 // `iterCarriedArgs`.
// NOTE(review): getOperand(redPos) assumes the terminator has more than
// `redPos` operands (invariants 4 and 5 above). This is not checked here.
331 Operation *terminatorOp = combinerOp;
332 if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
333 return nullptr;
334
335 return reducedVal;
336}
return success()
static llvm::ManagedStatic< PassManagerOptions > options
static void processValue(Value value, LiveMap &liveMap)
static LogicalResult getBackwardSliceImpl(Operation *op, DenseSet< Operation * > &visited, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options)
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp.
static void getForwardSliceImpl(Operation *op, DenseSet< Operation * > &visited, SetVector< Operation * > *forwardSlice, const SliceOptions::TransitiveFilter &filter=nullptr)
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class represents an operand of an operation.
Definition Value.h:257
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait.
Definition Operation.h:749
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:757
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
Definition Region.h:285
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Include the generated interface declarations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op); op itself is kept only when options.inclusive is set.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and the position of the potential reduction argument within the list, redPos.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
SliceOptions ForwardSliceOptions
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op); op itself is kept only when options.inclusive is set.
std::function< bool(Operation *)> TransitiveFilter
Type of the condition to limit the propagation of transitive use-defs.
TransitiveFilter filter