MLIR  22.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
1 //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 
22 ///
23 /// Implements Analysis functions specific to slicing in Function.
24 ///
25 
26 using namespace mlir;
27 
28 static void
30  SetVector<Operation *> *forwardSlice,
31  const SliceOptions::TransitiveFilter &filter = nullptr) {
32  if (!op)
33  return;
34 
35  // Evaluate whether we should keep this use.
36  // This is useful in particular to implement scoping; i.e. return the
37  // transitive forwardSlice in the current scope.
38  if (filter && !filter(op))
39  return;
40 
41  for (Region &region : op->getRegions())
42  for (Block &block : region)
43  for (Operation &blockOp : block)
44  if (forwardSlice->count(&blockOp) == 0) {
45  // We don't have to check if the 'blockOp' is already visited because
46  // there cannot be a traversal path from this nested op to the parent
47  // and thus a cycle cannot be closed here. We still have to mark it
48  // as visited to stop before visiting this operation again if it is
49  // part of a cycle.
50  visited.insert(&blockOp);
51  getForwardSliceImpl(&blockOp, visited, forwardSlice, filter);
52  visited.erase(&blockOp);
53  }
54 
55  for (Value result : op->getResults())
56  for (Operation *userOp : result.getUsers()) {
57  // A cycle can only occur within a basic block (not across regions or
58  // basic blocks) because the parent region must be a graph region, graph
59  // regions are restricted to always have 0 or 1 blocks, and there cannot
60  // be a def-use edge from a nested operation to an operation in an
61  // ancestor region. Therefore, we don't have to but may use the same
62  // 'visited' set across regions/blocks as long as we remove operations
63  // from the set again when the DFS traverses back from the leaf to the
64  // root.
65  if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second)
66  getForwardSliceImpl(userOp, visited, forwardSlice, filter);
67 
68  visited.erase(userOp);
69  }
70 
71  forwardSlice->insert(op);
72 }
73 
76  DenseSet<Operation *> visited;
77  visited.insert(op);
78  getForwardSliceImpl(op, visited, forwardSlice, options.filter);
79  if (!options.inclusive) {
80  // Don't insert the top level operation, we just queried on it and don't
81  // want it in the results.
82  forwardSlice->remove(op);
83  }
84 
85  // Reverse to get back the actual topological order.
86  // std::reverse does not work out of the box on SetVector and I want an
87  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
88  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
89  forwardSlice->insert(v.rbegin(), v.rend());
90 }
91 
93  const SliceOptions &options) {
94  DenseSet<Operation *> visited;
95  for (Operation *user : root.getUsers()) {
96  visited.insert(user);
97  getForwardSliceImpl(user, visited, forwardSlice, options.filter);
98  visited.erase(user);
99  }
100 
101  // Reverse to get back the actual topological order.
102  // std::reverse does not work out of the box on SetVector and I want an
103  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
104  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
105  forwardSlice->insert(v.rbegin(), v.rend());
106 }
107 
108 static LogicalResult getBackwardSliceImpl(Operation *op,
109  DenseSet<Operation *> &visited,
110  SetVector<Operation *> *backwardSlice,
111  const BackwardSliceOptions &options) {
112  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
113  return success();
114 
115  // Evaluate whether we should keep this def.
116  // This is useful in particular to implement scoping; i.e. return the
117  // transitive backwardSlice in the current scope.
118  if (options.filter && !options.filter(op))
119  return success();
120 
121  auto processValue = [&](Value value) {
122  if (auto *definingOp = value.getDefiningOp()) {
123  if (backwardSlice->count(definingOp) == 0 &&
124  visited.insert(definingOp).second)
125  return getBackwardSliceImpl(definingOp, visited, backwardSlice,
126  options);
127 
128  visited.erase(definingOp);
129  } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
130  if (options.omitBlockArguments)
131  return success();
132 
133  Block *block = blockArg.getOwner();
134  Operation *parentOp = block->getParentOp();
135  // TODO: determine whether we want to recurse backward into the other
136  // blocks of parentOp, which are not technically backward unless they flow
137  // into us. For now, just bail.
138  if (parentOp && backwardSlice->count(parentOp) == 0) {
139  if (parentOp->getNumRegions() == 1 &&
140  parentOp->getRegion(0).hasOneBlock()) {
141  return getBackwardSliceImpl(parentOp, visited, backwardSlice,
142  options);
143  }
144  }
145  } else {
146  return failure();
147  }
148  return success();
149  };
150 
151  bool succeeded = true;
152 
153  if (!options.omitUsesFromAbove) {
154  llvm::for_each(op->getRegions(), [&](Region &region) {
155  // Walk this region recursively to collect the regions that descend from
156  // this op's nested regions (inclusive).
157  SmallPtrSet<Region *, 4> descendents;
158  region.walk(
159  [&](Region *childRegion) { descendents.insert(childRegion); });
160  region.walk([&](Operation *op) {
161  for (OpOperand &operand : op->getOpOperands()) {
162  if (!descendents.contains(operand.get().getParentRegion()))
163  if (!processValue(operand.get()).succeeded()) {
164  return WalkResult::interrupt();
165  }
166  }
167  return WalkResult::advance();
168  });
169  });
170  }
171  llvm::for_each(op->getOperands(), processValue);
172 
173  backwardSlice->insert(op);
174  return success(succeeded);
175 }
176 
178  SetVector<Operation *> *backwardSlice,
179  const BackwardSliceOptions &options) {
180  DenseSet<Operation *> visited;
181  visited.insert(op);
182  LogicalResult result =
183  getBackwardSliceImpl(op, visited, backwardSlice, options);
184 
185  if (!options.inclusive) {
186  // Don't insert the top level operation, we just queried on it and don't
187  // want it in the results.
188  backwardSlice->remove(op);
189  }
190  return result;
191 }
192 
193 LogicalResult mlir::getBackwardSlice(Value root,
194  SetVector<Operation *> *backwardSlice,
195  const BackwardSliceOptions &options) {
196  if (Operation *definingOp = root.getDefiningOp()) {
197  return getBackwardSlice(definingOp, backwardSlice, options);
198  }
199  Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
200  return getBackwardSlice(bbAargOwner, backwardSlice, options);
201 }
202 
204 mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
205  const ForwardSliceOptions &forwardSliceOptions) {
207  slice.insert(op);
208 
209  unsigned currentIndex = 0;
210  SetVector<Operation *> backwardSlice;
211  SetVector<Operation *> forwardSlice;
212  while (currentIndex != slice.size()) {
213  auto *currentOp = (slice)[currentIndex];
214  // Compute and insert the backwardSlice starting from currentOp.
215  backwardSlice.clear();
216  LogicalResult result =
217  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
218  assert(result.succeeded());
219  (void)result;
220  slice.insert_range(backwardSlice);
221 
222  // Compute and insert the forwardSlice starting from currentOp.
223  forwardSlice.clear();
224  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
225  slice.insert_range(forwardSlice);
226  ++currentIndex;
227  }
228  return topologicalSort(slice);
229 }
230 
231 /// Returns true if `value` (transitively) depends on iteration-carried values
232 /// of the given `ancestorOp`.
233 static bool dependsOnCarriedVals(Value value,
234  ArrayRef<BlockArgument> iterCarriedArgs,
235  Operation *ancestorOp) {
236  // Compute the backward slice of the value.
238  BackwardSliceOptions sliceOptions;
239  sliceOptions.filter = [&](Operation *op) {
240  return !ancestorOp->isAncestor(op);
241  };
242  LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
243  assert(result.succeeded());
244  (void)result;
245 
246  // Check that none of the operands of the operations in the backward slice are
247  // loop iteration arguments, and neither is the value itself.
248  SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
249  if (iterCarriedValSet.contains(value))
250  return true;
251 
252  for (Operation *op : slice)
253  for (Value operand : op->getOperands())
254  if (iterCarriedValSet.contains(operand))
255  return true;
256 
257  return false;
258 }
259 
260 /// Utility to match a generic reduction given a list of iteration-carried
261 /// arguments, `iterCarriedArgs` and the position of the potential reduction
262 /// argument within the list, `redPos`. If a reduction is matched, returns the
263 /// reduced value and the topologically-sorted list of combiner operations
264 /// involved in the reduction. Otherwise, returns a null value.
265 ///
266 /// The matching algorithm relies on the following invariants, which are subject
267 /// to change:
268 /// 1. The first combiner operation must be a binary operation with the
269 /// iteration-carried value and the reduced value as operands.
270 /// 2. The iteration-carried value and combiner operations must be side
271 /// effect-free, have single result and a single use.
272 /// 3. Combiner operations must be immediately nested in the region op
273 /// performing the reduction.
274 /// 4. Reduction def-use chain must end in a terminator op that yields the
275 /// next iteration/output values in the same order as the iteration-carried
276 /// values in `iterCarriedArgs`.
277 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
278 /// of the region op performing the reduction.
279 ///
280 /// This utility is generic enough to detect reductions involving multiple
281 /// combiner operations (disabled for now) across multiple dialects, including
282 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
283 /// specific enum values for the combiner operations since its goal is also
284 /// matching reductions without pre-defined semantics in core MLIR. It's up to
285 /// each client to make sense out of the list of combiner operations. It's also
286 /// up to each client to check for additional invariants on the expected
287 /// reductions not covered by this generic matching.
289  unsigned redPos,
290  SmallVectorImpl<Operation *> &combinerOps) {
291  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
292 
293  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
294  if (!redCarriedVal.hasOneUse())
295  return nullptr;
296 
297  // For now, the first combiner op must be a binary op.
298  Operation *combinerOp = *redCarriedVal.getUsers().begin();
299  if (combinerOp->getNumOperands() != 2)
300  return nullptr;
301  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
302  ? combinerOp->getOperand(1)
303  : combinerOp->getOperand(0);
304 
305  Operation *redRegionOp =
306  iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
307  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
308  return nullptr;
309 
310  // Traverse the def-use chain starting from the first combiner op until a
311  // terminator is found. Gather all the combiner ops along the way in
312  // topological order.
313  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
314  if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
315  !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
316  return nullptr;
317 
318  combinerOps.push_back(combinerOp);
319  combinerOp = *combinerOp->getUsers().begin();
320  }
321 
322  // Limit matching to single combiner op until we can properly test reductions
323  // involving multiple combiners.
324  if (combinerOps.size() != 1)
325  return nullptr;
326 
327  // Check that the yielded value is in the same position as in
328  // `iterCarriedArgs`.
329  Operation *terminatorOp = combinerOp;
330  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
331  return nullptr;
332 
333  return reducedVal;
334 }
static llvm::ManagedStatic< PassManagerOptions > options
static void processValue(Value value, LiveMap &liveMap)
static LogicalResult getBackwardSliceImpl(Operation *op, DenseSet< Operation * > &visited, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options)
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp.
static void getForwardSliceImpl(Operation *op, DenseSet< Operation * > &visited, SetVector< Operation * > *forwardSlice, const SliceOptions::TransitiveFilter &filter=nullptr)
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class represents an operand of an operation.
Definition: Value.h:257
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:757
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
Include the generated interface declarations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
std::function< bool(Operation *)> TransitiveFilter
Type of the condition to limit the propagation of transitive use-defs.
Definition: SliceAnalysis.h:28
TransitiveFilter filter
Definition: SliceAnalysis.h:29