MLIR  20.0.0git
SliceAnalysis.cpp
Go to the documentation of this file.
1 //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements analysis functions for computing operation slices.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 
23 ///
24 /// Implements analysis functions for computing forward and backward slices.
25 ///
26 
27 using namespace mlir;
28 
29 static void
31  const SliceOptions::TransitiveFilter &filter = nullptr) {
32  if (!op)
33  return;
34 
35  // Evaluate whether we should keep this use.
36  // This is useful in particular to implement scoping; i.e. return the
37  // transitive forwardSlice in the current scope.
38  if (filter && !filter(op))
39  return;
40 
41  for (Region &region : op->getRegions())
42  for (Block &block : region)
43  for (Operation &blockOp : block)
44  if (forwardSlice->count(&blockOp) == 0)
45  getForwardSliceImpl(&blockOp, forwardSlice, filter);
46  for (Value result : op->getResults()) {
47  for (Operation *userOp : result.getUsers())
48  if (forwardSlice->count(userOp) == 0)
49  getForwardSliceImpl(userOp, forwardSlice, filter);
50  }
51 
52  forwardSlice->insert(op);
53 }
54 
57  getForwardSliceImpl(op, forwardSlice, options.filter);
58  if (!options.inclusive) {
59  // Don't insert the top level operation, we just queried on it and don't
60  // want it in the results.
61  forwardSlice->remove(op);
62  }
63 
64  // Reverse to get back the actual topological order.
65  // std::reverse does not work out of the box on SetVector and I want an
66  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
67  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
68  forwardSlice->insert(v.rbegin(), v.rend());
69 }
70 
72  const SliceOptions &options) {
73  for (Operation *user : root.getUsers())
74  getForwardSliceImpl(user, forwardSlice, options.filter);
75 
76  // Reverse to get back the actual topological order.
77  // std::reverse does not work out of the box on SetVector and I want an
78  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
79  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
80  forwardSlice->insert(v.rbegin(), v.rend());
81 }
82 
84  SetVector<Operation *> *backwardSlice,
86  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
87  return;
88 
89  // Evaluate whether we should keep this def.
90  // This is useful in particular to implement scoping; i.e. return the
91  // transitive backwardSlice in the current scope.
92  if (options.filter && !options.filter(op))
93  return;
94 
95  auto processValue = [&](Value value) {
96  if (auto *definingOp = value.getDefiningOp()) {
97  if (backwardSlice->count(definingOp) == 0)
98  getBackwardSliceImpl(definingOp, backwardSlice, options);
99  } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100  if (options.omitBlockArguments)
101  return;
102 
103  Block *block = blockArg.getOwner();
104  Operation *parentOp = block->getParentOp();
105  // TODO: determine whether we want to recurse backward into the other
106  // blocks of parentOp, which are not technically backward unless they flow
107  // into us. For now, just bail.
108  if (parentOp && backwardSlice->count(parentOp) == 0) {
109  assert(parentOp->getNumRegions() == 1 &&
110  parentOp->getRegion(0).getBlocks().size() == 1);
111  getBackwardSliceImpl(parentOp, backwardSlice, options);
112  }
113  } else {
114  llvm_unreachable("No definingOp and not a block argument.");
115  }
116  };
117 
118  if (!options.omitUsesFromAbove) {
119  llvm::for_each(op->getRegions(), [&](Region &region) {
120  // Walk this region recursively to collect the regions that descend from
121  // this op's nested regions (inclusive).
122  SmallPtrSet<Region *, 4> descendents;
123  region.walk(
124  [&](Region *childRegion) { descendents.insert(childRegion); });
125  region.walk([&](Operation *op) {
126  for (OpOperand &operand : op->getOpOperands()) {
127  if (!descendents.contains(operand.get().getParentRegion()))
128  processValue(operand.get());
129  }
130  });
131  });
132  }
133  llvm::for_each(op->getOperands(), processValue);
134 
135  backwardSlice->insert(op);
136 }
137 
139  SetVector<Operation *> *backwardSlice,
140  const BackwardSliceOptions &options) {
141  getBackwardSliceImpl(op, backwardSlice, options);
142 
143  if (!options.inclusive) {
144  // Don't insert the top level operation, we just queried on it and don't
145  // want it in the results.
146  backwardSlice->remove(op);
147  }
148 }
149 
151  const BackwardSliceOptions &options) {
152  if (Operation *definingOp = root.getDefiningOp()) {
153  getBackwardSlice(definingOp, backwardSlice, options);
154  return;
155  }
156  Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
157  getBackwardSlice(bbAargOwner, backwardSlice, options);
158 }
159 
161 mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
162  const ForwardSliceOptions &forwardSliceOptions) {
164  slice.insert(op);
165 
166  unsigned currentIndex = 0;
167  SetVector<Operation *> backwardSlice;
168  SetVector<Operation *> forwardSlice;
169  while (currentIndex != slice.size()) {
170  auto *currentOp = (slice)[currentIndex];
171  // Compute and insert the backwardSlice starting from currentOp.
172  backwardSlice.clear();
173  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
174  slice.insert(backwardSlice.begin(), backwardSlice.end());
175 
176  // Compute and insert the forwardSlice starting from currentOp.
177  forwardSlice.clear();
178  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
179  slice.insert(forwardSlice.begin(), forwardSlice.end());
180  ++currentIndex;
181  }
182  return topologicalSort(slice);
183 }
184 
185 /// Returns true if `value` (transitively) depends on iteration-carried values
186 /// of the given `ancestorOp`.
187 static bool dependsOnCarriedVals(Value value,
188  ArrayRef<BlockArgument> iterCarriedArgs,
189  Operation *ancestorOp) {
190  // Compute the backward slice of the value.
192  BackwardSliceOptions sliceOptions;
193  sliceOptions.filter = [&](Operation *op) {
194  return !ancestorOp->isAncestor(op);
195  };
196  getBackwardSlice(value, &slice, sliceOptions);
197 
198  // Check that none of the operands of the operations in the backward slice are
199  // loop iteration arguments, and neither is the value itself.
200  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
201  iterCarriedArgs.end());
202  if (iterCarriedValSet.contains(value))
203  return true;
204 
205  for (Operation *op : slice)
206  for (Value operand : op->getOperands())
207  if (iterCarriedValSet.contains(operand))
208  return true;
209 
210  return false;
211 }
212 
213 /// Utility to match a generic reduction given a list of iteration-carried
214 /// arguments, `iterCarriedArgs` and the position of the potential reduction
215 /// argument within the list, `redPos`. If a reduction is matched, returns the
216 /// reduced value and the topologically-sorted list of combiner operations
217 /// involved in the reduction. Otherwise, returns a null value.
218 ///
219 /// The matching algorithm relies on the following invariants, which are subject
220 /// to change:
221 /// 1. The first combiner operation must be a binary operation with the
222 /// iteration-carried value and the reduced value as operands.
223 /// 2. The iteration-carried value and combiner operations must be side
224 /// effect-free, have single result and a single use.
225 /// 3. Combiner operations must be immediately nested in the region op
226 /// performing the reduction.
227 /// 4. Reduction def-use chain must end in a terminator op that yields the
228 /// next iteration/output values in the same order as the iteration-carried
229 /// values in `iterCarriedArgs`.
230 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
231 /// of the region op performing the reduction.
232 ///
233 /// This utility is generic enough to detect reductions involving multiple
234 /// combiner operations (disabled for now) across multiple dialects, including
235 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
236 /// specific enum values for the combiner operations since its goal is also
237 /// matching reductions without pre-defined semantics in core MLIR. It's up to
238 /// each client to make sense out of the list of combiner operations. It's also
239 /// up to each client to check for additional invariants on the expected
240 /// reductions not covered by this generic matching.
242  unsigned redPos,
243  SmallVectorImpl<Operation *> &combinerOps) {
244  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
245 
246  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
247  if (!redCarriedVal.hasOneUse())
248  return nullptr;
249 
250  // For now, the first combiner op must be a binary op.
251  Operation *combinerOp = *redCarriedVal.getUsers().begin();
252  if (combinerOp->getNumOperands() != 2)
253  return nullptr;
254  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
255  ? combinerOp->getOperand(1)
256  : combinerOp->getOperand(0);
257 
258  Operation *redRegionOp =
259  iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
260  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
261  return nullptr;
262 
263  // Traverse the def-use chain starting from the first combiner op until a
264  // terminator is found. Gather all the combiner ops along the way in
265  // topological order.
266  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
267  if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
268  !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
269  return nullptr;
270 
271  combinerOps.push_back(combinerOp);
272  combinerOp = *combinerOp->getUsers().begin();
273  }
274 
275  // Limit matching to single combiner op until we can properly test reductions
276  // involving multiple combiners.
277  if (combinerOps.size() != 1)
278  return nullptr;
279 
280  // Check that the yielded value is in the same position as in
281  // `iterCarriedArgs`.
282  Operation *terminatorOp = combinerOp;
283  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
284  return nullptr;
285 
286  return reducedVal;
287 }
static llvm::ManagedStatic< PassManagerOptions > options
static void processValue(Value value, LiveMap &liveMap)
static void getForwardSliceImpl(Operation *op, SetVector< Operation * > *forwardSlice, const SliceOptions::TransitiveFilter &filter=nullptr)
static bool dependsOnCarriedVals(Value value, ArrayRef< BlockArgument > iterCarriedArgs, Operation *ancestorOp)
Returns true if value (transitively) depends on iteration-carried values of the given ancestorOp.
static void getBackwardSliceImpl(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options)
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class represents an operand of an operation.
Definition: Value.h:267
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:745
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:753
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
Definition: Region.h:285
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op), excluding op itself unless options.inclusive is set.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), excluding op itself unless options.inclusive is set.
std::function< bool(Operation *)> TransitiveFilter
Type of the condition to limit the propagation of transitive use-defs.
Definition: SliceAnalysis.h:29
TransitiveFilter filter
Definition: SliceAnalysis.h:30