MLIR  17.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
1 //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Operation.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 
21 ///
22 /// Implements Analysis functions specific to slicing in Function.
23 ///
24 
25 using namespace mlir;
26 
28  SetVector<Operation *> *forwardSlice,
29  TransitiveFilter filter) {
30  if (!op)
31  return;
32 
33  // Evaluate whether we should keep this use.
34  // This is useful in particular to implement scoping; i.e. return the
35  // transitive forwardSlice in the current scope.
36  if (filter && !filter(op))
37  return;
38 
39  for (Region &region : op->getRegions())
40  for (Block &block : region)
41  for (Operation &blockOp : block)
42  if (forwardSlice->count(&blockOp) == 0)
43  getForwardSliceImpl(&blockOp, forwardSlice, filter);
44  for (Value result : op->getResults()) {
45  for (Operation *userOp : result.getUsers())
46  if (forwardSlice->count(userOp) == 0)
47  getForwardSliceImpl(userOp, forwardSlice, filter);
48  }
49 
50  forwardSlice->insert(op);
51 }
52 
54  TransitiveFilter filter, bool inclusive) {
55  getForwardSliceImpl(op, forwardSlice, filter);
56  if (!inclusive) {
57  // Don't insert the top level operation, we just queried on it and don't
58  // want it in the results.
59  forwardSlice->remove(op);
60  }
61 
62  // Reverse to get back the actual topological order.
63  // std::reverse does not work out of the box on SetVector and I want an
64  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
65  std::vector<Operation *> v(forwardSlice->takeVector());
66  forwardSlice->insert(v.rbegin(), v.rend());
67 }
68 
70  TransitiveFilter filter, bool inclusive) {
71  for (Operation *user : root.getUsers())
72  getForwardSliceImpl(user, forwardSlice, filter);
73 
74  // Reverse to get back the actual topological order.
75  // std::reverse does not work out of the box on SetVector and I want an
76  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
77  std::vector<Operation *> v(forwardSlice->takeVector());
78  forwardSlice->insert(v.rbegin(), v.rend());
79 }
80 
82  SetVector<Operation *> *backwardSlice,
83  TransitiveFilter filter) {
84  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
85  return;
86 
87  // Evaluate whether we should keep this def.
88  // This is useful in particular to implement scoping; i.e. return the
89  // transitive backwardSlice in the current scope.
90  if (filter && !filter(op))
91  return;
92 
93  for (const auto &en : llvm::enumerate(op->getOperands())) {
94  auto operand = en.value();
95  if (auto *definingOp = operand.getDefiningOp()) {
96  if (backwardSlice->count(definingOp) == 0)
97  getBackwardSliceImpl(definingOp, backwardSlice, filter);
98  } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
99  Block *block = blockArg.getOwner();
100  Operation *parentOp = block->getParentOp();
101  // TODO: determine whether we want to recurse backward into the other
102  // blocks of parentOp, which are not technically backward unless they flow
103  // into us. For now, just bail.
104  if (parentOp && backwardSlice->count(parentOp) == 0) {
105  assert(parentOp->getNumRegions() == 1 &&
106  parentOp->getRegion(0).getBlocks().size() == 1);
107  getBackwardSliceImpl(parentOp, backwardSlice, filter);
108  }
109  } else {
110  llvm_unreachable("No definingOp and not a block argument.");
111  }
112  }
113 
114  backwardSlice->insert(op);
115 }
116 
118  SetVector<Operation *> *backwardSlice,
119  TransitiveFilter filter, bool inclusive) {
120  getBackwardSliceImpl(op, backwardSlice, filter);
121 
122  if (!inclusive) {
123  // Don't insert the top level operation, we just queried on it and don't
124  // want it in the results.
125  backwardSlice->remove(op);
126  }
127 }
128 
130  TransitiveFilter filter, bool inclusive) {
131  if (Operation *definingOp = root.getDefiningOp()) {
132  getBackwardSlice(definingOp, backwardSlice, filter, inclusive);
133  return;
134  }
135  Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
136  getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive);
137 }
138 
140  TransitiveFilter backwardFilter,
141  TransitiveFilter forwardFilter,
142  bool inclusive) {
144  slice.insert(op);
145 
146  unsigned currentIndex = 0;
147  SetVector<Operation *> backwardSlice;
148  SetVector<Operation *> forwardSlice;
149  while (currentIndex != slice.size()) {
150  auto *currentOp = (slice)[currentIndex];
151  // Compute and insert the backwardSlice starting from currentOp.
152  backwardSlice.clear();
153  getBackwardSlice(currentOp, &backwardSlice, backwardFilter, inclusive);
154  slice.insert(backwardSlice.begin(), backwardSlice.end());
155 
156  // Compute and insert the forwardSlice starting from currentOp.
157  forwardSlice.clear();
158  getForwardSlice(currentOp, &forwardSlice, forwardFilter, inclusive);
159  slice.insert(forwardSlice.begin(), forwardSlice.end());
160  ++currentIndex;
161  }
162  return topologicalSort(slice);
163 }
164 
165 namespace {
166 /// DFS post-order implementation that maintains a global count to work across
167 /// multiple invocations, to help implement topological sort on multi-root DAGs.
168 /// We traverse all operations but only record the ones that appear in
169 /// `toSort` for the final result.
170 struct DFSState {
171  DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
172  const SetVector<Operation *> &toSort;
173  SmallVector<Operation *, 16> topologicalCounts;
175 };
176 } // namespace
177 
178 static void dfsPostorder(Operation *root, DFSState *state) {
179  SmallVector<Operation *> queue(1, root);
180  std::vector<Operation *> ops;
181  while (!queue.empty()) {
182  Operation *current = queue.pop_back_val();
183  ops.push_back(current);
184  for (Operation *op : current->getUsers())
185  queue.push_back(op);
186  for (Region &region : current->getRegions()) {
187  for (Operation &op : region.getOps())
188  queue.push_back(&op);
189  }
190  }
191 
192  for (Operation *op : llvm::reverse(ops)) {
193  if (state->seen.insert(op).second && state->toSort.count(op) > 0)
194  state->topologicalCounts.push_back(op);
195  }
196 }
197 
200  if (toSort.empty()) {
201  return toSort;
202  }
203 
204  // Run from each root with global count and `seen` set.
205  DFSState state(toSort);
206  for (auto *s : toSort) {
207  assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
208  dfsPostorder(s, &state);
209  }
210 
211  // Reorder and return.
213  for (auto it = state.topologicalCounts.rbegin(),
214  eit = state.topologicalCounts.rend();
215  it != eit; ++it) {
216  res.insert(*it);
217  }
218  return res;
219 }
220 
221 /// Returns true if `value` (transitively) depends on iteration-carried values
222 /// of the given `ancestorOp`.
223 static bool dependsOnCarriedVals(Value value,
224  ArrayRef<BlockArgument> iterCarriedArgs,
225  Operation *ancestorOp) {
226  // Compute the backward slice of the value.
228  getBackwardSlice(value, &slice,
229  [&](Operation *op) { return !ancestorOp->isAncestor(op); });
230 
231  // Check that none of the operands of the operations in the backward slice are
232  // loop iteration arguments, and neither is the value itself.
233  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
234  iterCarriedArgs.end());
235  if (iterCarriedValSet.contains(value))
236  return true;
237 
238  for (Operation *op : slice)
239  for (Value operand : op->getOperands())
240  if (iterCarriedValSet.contains(operand))
241  return true;
242 
243  return false;
244 }
245 
246 /// Utility to match a generic reduction given a list of iteration-carried
247 /// arguments, `iterCarriedArgs` and the position of the potential reduction
248 /// argument within the list, `redPos`. If a reduction is matched, returns the
249 /// reduced value and the topologically-sorted list of combiner operations
250 /// involved in the reduction. Otherwise, returns a null value.
251 ///
252 /// The matching algorithm relies on the following invariants, which are subject
253 /// to change:
254 /// 1. The first combiner operation must be a binary operation with the
255 /// iteration-carried value and the reduced value as operands.
256 /// 2. The iteration-carried value and combiner operations must be side
257 /// effect-free, have single result and a single use.
258 /// 3. Combiner operations must be immediately nested in the region op
259 /// performing the reduction.
260 /// 4. Reduction def-use chain must end in a terminator op that yields the
261 /// next iteration/output values in the same order as the iteration-carried
262 /// values in `iterCarriedArgs`.
263 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
264 /// of the region op performing the reduction.
265 ///
266 /// This utility is generic enough to detect reductions involving multiple
267 /// combiner operations (disabled for now) across multiple dialects, including
268 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
269 /// specific enum values for the combiner operations since its goal is also
270 /// matching reductions without pre-defined semantics in core MLIR. It's up to
271 /// each client to make sense out of the list of combiner operations. It's also
272 /// up to each client to check for additional invariants on the expected
273 /// reductions not covered by this generic matching.
275  unsigned redPos,
276  SmallVectorImpl<Operation *> &combinerOps) {
277  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
278 
279  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
280  if (!redCarriedVal.hasOneUse())
281  return nullptr;
282 
283  // For now, the first combiner op must be a binary op.
284  Operation *combinerOp = *redCarriedVal.getUsers().begin();
285  if (combinerOp->getNumOperands() != 2)
286  return nullptr;
287  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
288  ? combinerOp->getOperand(1)
289  : combinerOp->getOperand(0);
290 
291  Operation *redRegionOp =
292  iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
293  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
294  return nullptr;
295 
296  // Traverse the def-use chain starting from the first combiner op until a
297  // terminator is found. Gather all the combiner ops along the way in
298  // topological order.
299  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
300  if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
301  !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
302  return nullptr;
303 
304  combinerOps.push_back(combinerOp);
305  combinerOp = *combinerOp->getUsers().begin();
306  }
307 
308  // Limit matching to single combiner op until we can properly test reductions
309  // involving multiple combiners.
310  if (combinerOps.size() != 1)
311  return nullptr;
312 
313  // Check that the yielded value is in the same position as in
314  // `iterCarriedArgs`.
315  Operation *terminatorOp = combinerOp;
316  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
317  return nullptr;
318 
319  return reducedVal;
320 }
static void getForwardSliceImpl(Operation *op, SetVector< Operation * > *forwardSlice, TransitiveFilter filter)
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp.
static void getBackwardSliceImpl(Operation *op, SetVector< Operation * > *backwardSlice, TransitiveFilter filter)
static void dfsPostorder(Operation *root, DFSState *state)
This class represents an argument of a Block.
Definition: Value.h:310
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:753
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:690
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:698
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:790
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:635
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:648
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:638
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:814
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
user_range getUsers() const
Definition: Value.h:222
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:209
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:262
This header declares functions that assist transformations in the MemRef dialect.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
SetVector< Operation * > getSlice(Operation *op, TransitiveFilter backwardFilter=nullptr, TransitiveFilter forwardFilter=nullptr, bool inclusive=false)
Iteratively computes backward slices and forward slices until a fixed point is reached.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, TransitiveFilter filter=nullptr, bool inclusive=false)
Fills forwardSlice with the computed forward slice (i.e.
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, TransitiveFilter filter=nullptr, bool inclusive=false)
Fills backwardSlice with the computed backward slice (i.e.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.