MLIR  16.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
1 //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 
20 ///
21 /// Implements Analysis functions specific to slicing in Function.
22 ///
23 
24 using namespace mlir;
25 
27  SetVector<Operation *> *forwardSlice,
28  TransitiveFilter filter) {
29  if (!op)
30  return;
31 
32  // Evaluate whether we should keep this use.
33  // This is useful in particular to implement scoping; i.e. return the
34  // transitive forwardSlice in the current scope.
35  if (filter && !filter(op))
36  return;
37 
38  for (Region &region : op->getRegions())
39  for (Block &block : region)
40  for (Operation &blockOp : block)
41  if (forwardSlice->count(&blockOp) == 0)
42  getForwardSliceImpl(&blockOp, forwardSlice, filter);
43  for (Value result : op->getResults()) {
44  for (Operation *userOp : result.getUsers())
45  if (forwardSlice->count(userOp) == 0)
46  getForwardSliceImpl(userOp, forwardSlice, filter);
47  }
48 
49  forwardSlice->insert(op);
50 }
51 
53  TransitiveFilter filter) {
54  getForwardSliceImpl(op, forwardSlice, filter);
55  // Don't insert the top level operation, we just queried on it and don't
56  // want it in the results.
57  forwardSlice->remove(op);
58 
59  // Reverse to get back the actual topological order.
60  // std::reverse does not work out of the box on SetVector and I want an
61  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
62  std::vector<Operation *> v(forwardSlice->takeVector());
63  forwardSlice->insert(v.rbegin(), v.rend());
64 }
65 
67  TransitiveFilter filter) {
68  for (Operation *user : root.getUsers())
69  getForwardSliceImpl(user, forwardSlice, filter);
70 
71  // Reverse to get back the actual topological order.
72  // std::reverse does not work out of the box on SetVector and I want an
73  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
74  std::vector<Operation *> v(forwardSlice->takeVector());
75  forwardSlice->insert(v.rbegin(), v.rend());
76 }
77 
79  SetVector<Operation *> *backwardSlice,
80  TransitiveFilter filter) {
81  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
82  return;
83 
84  // Evaluate whether we should keep this def.
85  // This is useful in particular to implement scoping; i.e. return the
86  // transitive backwardSlice in the current scope.
87  if (filter && !filter(op))
88  return;
89 
90  for (const auto &en : llvm::enumerate(op->getOperands())) {
91  auto operand = en.value();
92  if (auto *definingOp = operand.getDefiningOp()) {
93  if (backwardSlice->count(definingOp) == 0)
94  getBackwardSliceImpl(definingOp, backwardSlice, filter);
95  } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
96  Block *block = blockArg.getOwner();
97  Operation *parentOp = block->getParentOp();
98  // TODO: determine whether we want to recurse backward into the other
99  // blocks of parentOp, which are not technically backward unless they flow
100  // into us. For now, just bail.
101  assert(parentOp->getNumRegions() == 1 &&
102  parentOp->getRegion(0).getBlocks().size() == 1);
103  if (backwardSlice->count(parentOp) == 0)
104  getBackwardSliceImpl(parentOp, backwardSlice, filter);
105  } else {
106  llvm_unreachable("No definingOp and not a block argument.");
107  }
108  }
109 
110  backwardSlice->insert(op);
111 }
112 
114  SetVector<Operation *> *backwardSlice,
115  TransitiveFilter filter) {
116  getBackwardSliceImpl(op, backwardSlice, filter);
117 
118  // Don't insert the top level operation, we just queried on it and don't
119  // want it in the results.
120  backwardSlice->remove(op);
121 }
122 
124  TransitiveFilter filter) {
125  if (Operation *definingOp = root.getDefiningOp()) {
126  getBackwardSlice(definingOp, backwardSlice, filter);
127  return;
128  }
129  Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
130  getBackwardSlice(bbAargOwner, backwardSlice, filter);
131 }
132 
134  TransitiveFilter backwardFilter,
135  TransitiveFilter forwardFilter) {
137  slice.insert(op);
138 
139  unsigned currentIndex = 0;
140  SetVector<Operation *> backwardSlice;
141  SetVector<Operation *> forwardSlice;
142  while (currentIndex != slice.size()) {
143  auto *currentOp = (slice)[currentIndex];
144  // Compute and insert the backwardSlice starting from currentOp.
145  backwardSlice.clear();
146  getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
147  slice.insert(backwardSlice.begin(), backwardSlice.end());
148 
149  // Compute and insert the forwardSlice starting from currentOp.
150  forwardSlice.clear();
151  getForwardSlice(currentOp, &forwardSlice, forwardFilter);
152  slice.insert(forwardSlice.begin(), forwardSlice.end());
153  ++currentIndex;
154  }
155  return topologicalSort(slice);
156 }
157 
158 namespace {
159 /// DFS post-order implementation that maintains a global count to work across
160 /// multiple invocations, to help implement topological sort on multi-root DAGs.
161 /// We traverse all operations but only record the ones that appear in
162 /// `toSort` for the final result.
163 struct DFSState {
164  DFSState(const SetVector<Operation *> &set)
165  : toSort(set), topologicalCounts(), seen() {}
166  const SetVector<Operation *> &toSort;
167  SmallVector<Operation *, 16> topologicalCounts;
169 };
170 } // namespace
171 
172 static void dfsPostorder(Operation *root, DFSState *state) {
173  SmallVector<Operation *> queue(1, root);
174  std::vector<Operation *> ops;
175  while (!queue.empty()) {
176  Operation *current = queue.pop_back_val();
177  ops.push_back(current);
178  for (Value result : current->getResults()) {
179  for (Operation *op : result.getUsers())
180  queue.push_back(op);
181  }
182  for (Region &region : current->getRegions()) {
183  for (Operation &op : region.getOps())
184  queue.push_back(&op);
185  }
186  }
187 
188  for (Operation *op : llvm::reverse(ops)) {
189  if (state->seen.insert(op).second && state->toSort.count(op) > 0)
190  state->topologicalCounts.push_back(op);
191  }
192 }
193 
196  if (toSort.empty()) {
197  return toSort;
198  }
199 
200  // Run from each root with global count and `seen` set.
201  DFSState state(toSort);
202  for (auto *s : toSort) {
203  assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
204  dfsPostorder(s, &state);
205  }
206 
207  // Reorder and return.
209  for (auto it = state.topologicalCounts.rbegin(),
210  eit = state.topologicalCounts.rend();
211  it != eit; ++it) {
212  res.insert(*it);
213  }
214  return res;
215 }
216 
217 /// Returns true if `value` (transitively) depends on iteration-carried values
218 /// of the given `ancestorOp`.
220  ArrayRef<BlockArgument> iterCarriedArgs,
221  Operation *ancestorOp) {
222  // Compute the backward slice of the value.
224  getBackwardSlice(value, &slice,
225  [&](Operation *op) { return !ancestorOp->isAncestor(op); });
226 
227  // Check that none of the operands of the operations in the backward slice are
228  // loop iteration arguments, and neither is the value itself.
229  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
230  iterCarriedArgs.end());
231  if (iterCarriedValSet.contains(value))
232  return true;
233 
234  for (Operation *op : slice)
235  for (Value operand : op->getOperands())
236  if (iterCarriedValSet.contains(operand))
237  return true;
238 
239  return false;
240 }
241 
242 /// Utility to match a generic reduction given a list of iteration-carried
243 /// arguments, `iterCarriedArgs` and the position of the potential reduction
244 /// argument within the list, `redPos`. If a reduction is matched, returns the
245 /// reduced value and the topologically-sorted list of combiner operations
246 /// involved in the reduction. Otherwise, returns a null value.
247 ///
248 /// The matching algorithm relies on the following invariants, which are subject
249 /// to change:
250 /// 1. The first combiner operation must be a binary operation with the
251 /// iteration-carried value and the reduced value as operands.
252 /// 2. The iteration-carried value and combiner operations must be side
253 /// effect-free, have single result and a single use.
254 /// 3. Combiner operations must be immediately nested in the region op
255 /// performing the reduction.
256 /// 4. Reduction def-use chain must end in a terminator op that yields the
257 /// next iteration/output values in the same order as the iteration-carried
258 /// values in `iterCarriedArgs`.
259 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
260 /// of the region op performing the reduction.
261 ///
262 /// This utility is generic enough to detect reductions involving multiple
263 /// combiner operations (disabled for now) across multiple dialects, including
264 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
265 /// specific enum values for the combiner operations since its goal is also
266 /// matching reductions without pre-defined semantics in core MLIR. It's up to
267 /// each client to make sense out of the list of combiner operations. It's also
268 /// up to each client to check for additional invariants on the expected
269 /// reductions not covered by this generic matching.
271  unsigned redPos,
272  SmallVectorImpl<Operation *> &combinerOps) {
273  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
274 
275  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
276  if (!redCarriedVal.hasOneUse())
277  return nullptr;
278 
279  // For now, the first combiner op must be a binary op.
280  Operation *combinerOp = *redCarriedVal.getUsers().begin();
281  if (combinerOp->getNumOperands() != 2)
282  return nullptr;
283  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
284  ? combinerOp->getOperand(1)
285  : combinerOp->getOperand(0);
286 
287  Operation *redRegionOp =
288  iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
289  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
290  return nullptr;
291 
292  // Traverse the def-use chain starting from the first combiner op until a
293  // terminator is found. Gather all the combiner ops along the way in
294  // topological order.
295  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
296  if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) ||
297  combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() ||
298  combinerOp->getParentOp() != redRegionOp)
299  return nullptr;
300 
301  combinerOps.push_back(combinerOp);
302  combinerOp = *combinerOp->getUsers().begin();
303  }
304 
305  // Limit matching to single combiner op until we can properly test reductions
306  // involving multiple combiners.
307  if (combinerOps.size() != 1)
308  return nullptr;
309 
310  // Check that the yielded value is in the same position as in
311  // `iterCarriedArgs`.
312  Operation *terminatorOp = combinerOp;
313  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
314  return nullptr;
315 
316  return reducedVal;
317 }
Include the generated interface declarations.
static void getForwardSliceImpl(Operation *op, SetVector< Operation *> *forwardSlice, TransitiveFilter filter)
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:194
BlockListType & getBlocks()
Definition: Region.h:45
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Block represents an ordered list of Operations.
Definition: Block.h:29
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:626
Value getOperand(unsigned idx)
Definition: Operation.h:267
void getBackwardSlice(Operation *op, SetVector< Operation *> *backwardSlice, TransitiveFilter filter=nullptr)
Fills backwardSlice with the computed backward slice (i.e.
SetVector< Operation * > getSlice(Operation *op, TransitiveFilter backwardFilter=nullptr, TransitiveFilter forwardFilter=nullptr)
Iteratively computes backward slices and forward slices until a fixed point is reached.
unsigned getNumOperands()
Definition: Operation.h:263
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:536
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:707
user_range getUsers() const
Definition: Value.h:213
static constexpr const bool value
SetVector< Operation * > topologicalSort(const SetVector< Operation *> &toSort)
Multi-root DAG topological sort.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
void getForwardSlice(Operation *op, SetVector< Operation *> *forwardSlice, TransitiveFilter filter=nullptr)
Fills forwardSlice with the computed forward slice (i.e.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation *> &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
This class represents an argument of a Block.
Definition: Value.h:300
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static void dfsPostorder(Operation *root, DFSState *state)
static void getBackwardSliceImpl(Operation *op, SetVector< Operation *> *backwardSlice, TransitiveFilter filter)
This class provides the API for ops that are known to be isolated from above.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
U cast() const
Definition: Value.h:108
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp...
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:650
result_range getResults()
Definition: Operation.h:332
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486