MLIR  18.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
1 //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Operation.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 
21 ///
22 /// Implements Analysis functions specific to slicing in Function.
23 ///
24 
25 using namespace mlir;
26 
27 static void
29  const SliceOptions::TransitiveFilter &filter = nullptr) {
30  if (!op)
31  return;
32 
33  // Evaluate whether we should keep this use.
34  // This is useful in particular to implement scoping; i.e. return the
35  // transitive forwardSlice in the current scope.
36  if (filter && !filter(op))
37  return;
38 
39  for (Region &region : op->getRegions())
40  for (Block &block : region)
41  for (Operation &blockOp : block)
42  if (forwardSlice->count(&blockOp) == 0)
43  getForwardSliceImpl(&blockOp, forwardSlice, filter);
44  for (Value result : op->getResults()) {
45  for (Operation *userOp : result.getUsers())
46  if (forwardSlice->count(userOp) == 0)
47  getForwardSliceImpl(userOp, forwardSlice, filter);
48  }
49 
50  forwardSlice->insert(op);
51 }
52 
55  getForwardSliceImpl(op, forwardSlice, options.filter);
56  if (!options.inclusive) {
57  // Don't insert the top level operation, we just queried on it and don't
58  // want it in the results.
59  forwardSlice->remove(op);
60  }
61 
62  // Reverse to get back the actual topological order.
63  // SetVector cannot be reversed in place with std::reverse, so move its
64  // contents out and re-insert them in reverse order.
65  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
66  forwardSlice->insert(v.rbegin(), v.rend());
67 }
68 
70  const SliceOptions &options) {
71  for (Operation *user : root.getUsers())
72  getForwardSliceImpl(user, forwardSlice, options.filter);
73 
74  // Reverse to get back the actual topological order.
75  // SetVector cannot be reversed in place with std::reverse, so move its
76  // contents out and re-insert them in reverse order.
77  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
78  forwardSlice->insert(v.rbegin(), v.rend());
79 }
80 
82  SetVector<Operation *> *backwardSlice,
84  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
85  return;
86 
87  // Evaluate whether we should keep this def.
88  // This is useful in particular to implement scoping; i.e. return the
89  // transitive backwardSlice in the current scope.
90  if (options.filter && !options.filter(op))
91  return;
92 
93  for (const auto &en : llvm::enumerate(op->getOperands())) {
94  auto operand = en.value();
95  if (auto *definingOp = operand.getDefiningOp()) {
96  if (backwardSlice->count(definingOp) == 0)
97  getBackwardSliceImpl(definingOp, backwardSlice, options);
98  } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
99  if (options.omitBlockArguments)
100  continue;
101 
102  Block *block = blockArg.getOwner();
103  Operation *parentOp = block->getParentOp();
104  // TODO: determine whether we want to recurse backward into the other
105  // blocks of parentOp, which are not technically backward unless they flow
106  // into us. For now, just bail.
107  if (parentOp && backwardSlice->count(parentOp) == 0) {
108  assert(parentOp->getNumRegions() == 1 &&
109  parentOp->getRegion(0).getBlocks().size() == 1);
110  getBackwardSliceImpl(parentOp, backwardSlice, options);
111  }
112  } else {
113  llvm_unreachable("No definingOp and not a block argument.");
114  }
115  }
116 
117  backwardSlice->insert(op);
118 }
119 
121  SetVector<Operation *> *backwardSlice,
122  const BackwardSliceOptions &options) {
123  getBackwardSliceImpl(op, backwardSlice, options);
124 
125  if (!options.inclusive) {
126  // Don't insert the top level operation, we just queried on it and don't
127  // want it in the results.
128  backwardSlice->remove(op);
129  }
130 }
131 
133  const BackwardSliceOptions &options) {
134  if (Operation *definingOp = root.getDefiningOp()) {
135  getBackwardSlice(definingOp, backwardSlice, options);
136  return;
137  }
138  Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
139  getBackwardSlice(bbAargOwner, backwardSlice, options);
140 }
141 
143 mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
144  const ForwardSliceOptions &forwardSliceOptions) {
146  slice.insert(op);
147 
148  unsigned currentIndex = 0;
149  SetVector<Operation *> backwardSlice;
150  SetVector<Operation *> forwardSlice;
151  while (currentIndex != slice.size()) {
152  auto *currentOp = (slice)[currentIndex];
153  // Compute and insert the backwardSlice starting from currentOp.
154  backwardSlice.clear();
155  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
156  slice.insert(backwardSlice.begin(), backwardSlice.end());
157 
158  // Compute and insert the forwardSlice starting from currentOp.
159  forwardSlice.clear();
160  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
161  slice.insert(forwardSlice.begin(), forwardSlice.end());
162  ++currentIndex;
163  }
164  return topologicalSort(slice);
165 }
166 
167 namespace {
168 /// DFS post-order implementation that maintains a global count to work across
169 /// multiple invocations, to help implement topological sort on multi-root DAGs.
170 /// We traverse all operations but only record the ones that appear in
171 /// `toSort` for the final result.
172 struct DFSState {
173  DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
174  const SetVector<Operation *> &toSort;
175  SmallVector<Operation *, 16> topologicalCounts;
177 };
178 } // namespace
179 
180 static void dfsPostorder(Operation *root, DFSState *state) {
181  SmallVector<Operation *> queue(1, root);
182  std::vector<Operation *> ops;
183  while (!queue.empty()) {
184  Operation *current = queue.pop_back_val();
185  ops.push_back(current);
186  for (Operation *op : current->getUsers())
187  queue.push_back(op);
188  for (Region &region : current->getRegions()) {
189  for (Operation &op : region.getOps())
190  queue.push_back(&op);
191  }
192  }
193 
194  for (Operation *op : llvm::reverse(ops)) {
195  if (state->seen.insert(op).second && state->toSort.count(op) > 0)
196  state->topologicalCounts.push_back(op);
197  }
198 }
199 
202  if (toSort.empty()) {
203  return toSort;
204  }
205 
206  // Run from each root with global count and `seen` set.
207  DFSState state(toSort);
208  for (auto *s : toSort) {
209  assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
210  dfsPostorder(s, &state);
211  }
212 
213  // Reorder and return.
215  for (auto it = state.topologicalCounts.rbegin(),
216  eit = state.topologicalCounts.rend();
217  it != eit; ++it) {
218  res.insert(*it);
219  }
220  return res;
221 }
222 
223 /// Returns true if `value` (transitively) depends on iteration-carried values
224 /// of the given `ancestorOp`.
225 static bool dependsOnCarriedVals(Value value,
226  ArrayRef<BlockArgument> iterCarriedArgs,
227  Operation *ancestorOp) {
228  // Compute the backward slice of the value.
230  BackwardSliceOptions sliceOptions;
231  sliceOptions.filter = [&](Operation *op) {
232  return !ancestorOp->isAncestor(op);
233  };
234  getBackwardSlice(value, &slice, sliceOptions);
235 
236  // Check that none of the operands of the operations in the backward slice are
237  // loop iteration arguments, and neither is the value itself.
238  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
239  iterCarriedArgs.end());
240  if (iterCarriedValSet.contains(value))
241  return true;
242 
243  for (Operation *op : slice)
244  for (Value operand : op->getOperands())
245  if (iterCarriedValSet.contains(operand))
246  return true;
247 
248  return false;
249 }
250 
251 /// Utility to match a generic reduction given a list of iteration-carried
252 /// arguments, `iterCarriedArgs` and the position of the potential reduction
253 /// argument within the list, `redPos`. If a reduction is matched, returns the
254 /// reduced value and the topologically-sorted list of combiner operations
255 /// involved in the reduction. Otherwise, returns a null value.
256 ///
257 /// The matching algorithm relies on the following invariants, which are subject
258 /// to change:
259 /// 1. The first combiner operation must be a binary operation with the
260 /// iteration-carried value and the reduced value as operands.
261 /// 2. The iteration-carried value and combiner operations must be side
262 /// effect-free, have single result and a single use.
263 /// 3. Combiner operations must be immediately nested in the region op
264 /// performing the reduction.
265 /// 4. Reduction def-use chain must end in a terminator op that yields the
266 /// next iteration/output values in the same order as the iteration-carried
267 /// values in `iterCarriedArgs`.
268 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
269 /// of the region op performing the reduction.
270 ///
271 /// This utility is generic enough to detect reductions involving multiple
272 /// combiner operations (disabled for now) across multiple dialects, including
273 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
274 /// specific enum values for the combiner operations since its goal is also
275 /// matching reductions without pre-defined semantics in core MLIR. It's up to
276 /// each client to make sense out of the list of combiner operations. It's also
277 /// up to each client to check for additional invariants on the expected
278 /// reductions not covered by this generic matching.
280  unsigned redPos,
281  SmallVectorImpl<Operation *> &combinerOps) {
282  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
283 
284  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
285  if (!redCarriedVal.hasOneUse())
286  return nullptr;
287 
288  // For now, the first combiner op must be a binary op.
289  Operation *combinerOp = *redCarriedVal.getUsers().begin();
290  if (combinerOp->getNumOperands() != 2)
291  return nullptr;
292  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
293  ? combinerOp->getOperand(1)
294  : combinerOp->getOperand(0);
295 
296  Operation *redRegionOp =
297  iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
298  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
299  return nullptr;
300 
301  // Traverse the def-use chain starting from the first combiner op until a
302  // terminator is found. Gather all the combiner ops along the way in
303  // topological order.
304  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
305  if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
306  !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
307  return nullptr;
308 
309  combinerOps.push_back(combinerOp);
310  combinerOp = *combinerOp->getUsers().begin();
311  }
312 
313  // Limit matching to single combiner op until we can properly test reductions
314  // involving multiple combiners.
315  if (combinerOps.size() != 1)
316  return nullptr;
317 
318  // Check that the yielded value is in the same position as in
319  // `iterCarriedArgs`.
320  Operation *terminatorOp = combinerOp;
321  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
322  return nullptr;
323 
324  return reducedVal;
325 }
static llvm::ManagedStatic< PassManagerOptions > options
static void getForwardSliceImpl(Operation *op, SetVector< Operation * > *forwardSlice, const SliceOptions::TransitiveFilter &filter=nullptr)
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp.
static void getBackwardSliceImpl(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options)
static void dfsPostorder(Operation *root, DFSState *state)
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:762
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:736
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:828
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:224
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:211
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
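Fills backwardSlice with the computed backward slice (i.e. the operations that op transitively depends on); op itself is kept only when options.inclusive is set.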
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
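Fills forwardSlice with the computed forward slice (i.e. the operations that transitively use the results of op or are nested in its regions); op itself is kept only when options.inclusive is set.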
std::function< bool(Operation *)> TransitiveFilter
Type of the condition to limit the propagation of transitive use-defs.
Definition: SliceAnalysis.h:29
TransitiveFilter filter
Definition: SliceAnalysis.h:30