MLIR  21.0.0git
SliceMatchers.h
Go to the documentation of this file.
1 //===- SliceMatchers.h - Matchers for slicing analysis ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines slicing-analysis matchers that extend and abstract the
10 // core implementations from `SliceAnalysis.h`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
15 #define MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
16 
18 #include "mlir/IR/Operation.h"
19 
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
/// if `innerMatcher` matches. The traversal stops once the desired depth level
/// is reached.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
/// ============================
///    1       2      3      4
///    |_______|      |______|
///    |   |             |
///    |   5             6
///    |___|_____________|
///      |               |
///      7               8
///      |_______________|
///              |
///              9
///
/// Assuming all local orders match the numbering order:
///     {5, 7, 6, 8, 9}
40 namespace mlir::query::matcher {
41 
42 template <typename Matcher>
44 public:
45  BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
46  bool omitBlockArguments, bool omitUsesFromAbove)
47  : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
48  inclusive(inclusive), omitBlockArguments(omitBlockArguments),
49  omitUsesFromAbove(omitUsesFromAbove) {}
50 
51  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
53  options.inclusive = inclusive;
54  options.omitUsesFromAbove = omitUsesFromAbove;
55  options.omitBlockArguments = omitBlockArguments;
56  return (innerMatcher.match(rootOp) &&
57  matches(rootOp, backwardSlice, options, maxDepth));
58  }
59 
60 private:
61  bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
62  BackwardSliceOptions &options, int64_t maxDepth);
63 
64 private:
65  // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
66  // to determine whether we want to traverse the IR or not. For example, we
67  // want to explore the IR only if the top-level operation name is
68  // `"arith.addf"`.
69  Matcher innerMatcher;
70  // `maxDepth` specifies the maximum depth that the matcher can traverse the
71  // IR. For example, if `maxDepth` is 2, the matcher will explore the defining
72  // operations of the top-level op up to 2 levels.
73  int64_t maxDepth;
74  bool inclusive;
75  bool omitBlockArguments;
76  bool omitUsesFromAbove;
77 };
78 
79 template <typename Matcher>
80 bool BackwardSliceMatcher<Matcher>::matches(
81  Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
82  BackwardSliceOptions &options, int64_t maxDepth) {
83  backwardSlice.clear();
85  // Initializing the root op with a depth of 0
86  opDepths[rootOp] = 0;
87  options.filter = [&](Operation *subOp) {
88  // If the subOp hasn't been recorded in opDepths, it is deeper than
89  // maxDepth.
90  if (!opDepths.contains(subOp))
91  return false;
92  // Examine subOp's operands to compute depths of their defining operations.
93  for (auto operand : subOp->getOperands()) {
94  int64_t newDepth = opDepths[subOp] + 1;
95  // If the newDepth is greater than maxDepth, further computation can be
96  // skipped.
97  if (newDepth > maxDepth)
98  continue;
99 
100  if (auto definingOp = operand.getDefiningOp()) {
101  // Registers the minimum depth
102  if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
103  opDepths[definingOp] = newDepth;
104  } else {
105  auto blockArgument = cast<BlockArgument>(operand);
106  Operation *parentOp = blockArgument.getOwner()->getParentOp();
107  if (!parentOp)
108  continue;
109 
110  if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
111  opDepths[parentOp] = newDepth;
112  }
113  }
114  return true;
115  };
116  LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
117  assert(result.succeeded() && "expected backward slice to succeed");
118  (void)result;
119  return options.inclusive ? backwardSlice.size() > 1
120  : backwardSlice.size() >= 1;
121 }
122 
123 /// Computes the backward-slice of all transitive defs reachable from `rootOp`,
124 /// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
125 template <typename BaseMatcher, typename Filter>
127 public:
128  PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
129  bool inclusive, bool omitBlockArguments,
130  bool omitUsesFromAbove)
131  : innerMatcher(std::move(innerMatcher)),
132  filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
133  omitBlockArguments(omitBlockArguments),
134  omitUsesFromAbove(omitUsesFromAbove) {}
135 
136  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
137  backwardSlice.clear();
139  options.inclusive = inclusive;
140  options.omitUsesFromAbove = omitUsesFromAbove;
141  options.omitBlockArguments = omitBlockArguments;
142  if (innerMatcher.match(rootOp)) {
143  options.filter = [&](Operation *subOp) {
144  return !filterMatcher.match(subOp);
145  };
146  LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
147  assert(result.succeeded() && "expected backward slice to succeed");
148  (void)result;
149  return options.inclusive ? backwardSlice.size() > 1
150  : backwardSlice.size() >= 1;
151  }
152  return false;
153  }
154 
155 private:
156  BaseMatcher innerMatcher;
157  Filter filterMatcher;
158  bool inclusive;
159  bool omitBlockArguments;
160  bool omitUsesFromAbove;
161 };
162 
163 /// Computes the forward-slice of all users reachable from `rootOp`,
164 /// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
165 template <typename BaseMatcher, typename Filter>
167 public:
168  PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
169  bool inclusive)
170  : innerMatcher(std::move(innerMatcher)),
171  filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
172 
173  bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
174  forwardSlice.clear();
176  options.inclusive = inclusive;
177  if (innerMatcher.match(rootOp)) {
178  options.filter = [&](Operation *subOp) {
179  return !filterMatcher.match(subOp);
180  };
181  getForwardSlice(rootOp, &forwardSlice, options);
182  return options.inclusive ? forwardSlice.size() > 1
183  : forwardSlice.size() >= 1;
184  }
185  return false;
186  }
187 
188 private:
189  BaseMatcher innerMatcher;
190  Filter filterMatcher;
191  bool inclusive;
192 };
193 
194 /// Matches transitive defs of a top-level operation up to N levels.
195 template <typename Matcher>
196 inline BackwardSliceMatcher<Matcher>
197 m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
198  bool omitBlockArguments, bool omitUsesFromAbove) {
199  assert(maxDepth >= 0 && "maxDepth must be non-negative");
200  return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
201  inclusive, omitBlockArguments,
202  omitUsesFromAbove);
203 }
204 
205 /// Matches all transitive defs of a top-level operation up to N levels.
206 template <typename Matcher>
208  int64_t maxDepth) {
209  assert(maxDepth >= 0 && "maxDepth must be non-negative");
210  return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true,
211  false, false);
212 }
213 
214 /// Matches all transitive defs of a top-level operation and stops where
215 /// `filterMatcher` rejects.
216 template <typename BaseMatcher, typename Filter>
217 inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
218 m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
219  bool inclusive, bool omitBlockArguments,
220  bool omitUsesFromAbove) {
222  std::move(innerMatcher), std::move(filterMatcher), inclusive,
223  omitBlockArguments, omitUsesFromAbove);
224 }
225 
226 /// Matches all users of a top-level operation and stops where
227 /// `filterMatcher` rejects.
228 template <typename BaseMatcher, typename Filter>
229 inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
230 m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
231  bool inclusive) {
233  std::move(innerMatcher), std::move(filterMatcher), inclusive);
234 }
235 
236 } // namespace mlir::query::matcher
237 
238 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
Definition: SliceMatchers.h:45
bool match(Operation *rootOp, SetVector< Operation * > &backwardSlice)
Definition: SliceMatchers.h:51
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
bool match(Operation *rootOp, SetVector< Operation * > &backwardSlice)
Computes the forward-slice of all users reachable from rootOp, if innerMatcher matches.
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive)
bool match(Operation *rootOp, SetVector< Operation * > &forwardSlice)
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
Definition: ErrorBuilder.h:20
PredicateForwardSliceMatcher< BaseMatcher, Filter > m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive)
Matches all users of a top-level operation and stops where filterMatcher rejects.
PredicateBackwardSliceMatcher< BaseMatcher, Filter > m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
Matches all transitive defs of a top-level operation and stops where filterMatcher rejects.
BackwardSliceMatcher< Matcher > m_GetAllDefinitions(Matcher innerMatcher, int64_t maxDepth)
Matches all transitive defs of a top-level operation up to N levels.
BackwardSliceMatcher< Matcher > m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
Matches transitive defs of a top-level operation up to N levels.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.