MLIR 22.0.0git
SliceMatchers.h
Go to the documentation of this file.
//===- SliceMatchers.h - Matchers for slicing analysis ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines slicing-analysis matchers that extend and abstract the
// core implementations from `SliceAnalysis.h`.
//
//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
15#define MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
16
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"
19
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
/// if `innerMatcher` matches. The traversal stops once the desired depth level
/// is reached.
///
/// Example: starting from node 9, assuming the matcher computes the slice for
/// the first two depth levels:
/// ============================
///    1       2      3      4
///    |_______|      |______|
///    |   |             |
///    |   5             6
///    |___|_____________|
///      |               |
///      7               8
///      |_______________|
///              |
///              9
///
/// Assuming all local orders match the numbering order, and with `inclusive`
/// set so that the root (9) is part of the slice, the result is:
///     {1, 5, 6, 7, 8, 9}
40namespace mlir::query::matcher {
41
42template <typename Matcher>
44public:
45 BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
46 bool omitBlockArguments, bool omitUsesFromAbove)
47 : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
48 inclusive(inclusive), omitBlockArguments(omitBlockArguments),
49 omitUsesFromAbove(omitUsesFromAbove) {}
50
51 bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
53 options.inclusive = inclusive;
54 options.omitUsesFromAbove = omitUsesFromAbove;
55 options.omitBlockArguments = omitBlockArguments;
56 return (innerMatcher.match(rootOp) &&
57 matches(rootOp, backwardSlice, options, maxDepth));
58 }
59
60private:
61 bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
63
64private:
65 // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
66 // to determine whether we want to traverse the IR or not. For example, we
67 // want to explore the IR only if the top-level operation name is
68 // `"arith.addf"`.
69 Matcher innerMatcher;
70 // `maxDepth` specifies the maximum depth that the matcher can traverse the
71 // IR. For example, if `maxDepth` is 2, the matcher will explore the defining
72 // operations of the top-level op up to 2 levels.
73 int64_t maxDepth;
74 bool inclusive;
75 bool omitBlockArguments;
76 bool omitUsesFromAbove;
77};
78
79template <typename Matcher>
80bool BackwardSliceMatcher<Matcher>::matches(
81 Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
83 backwardSlice.clear();
85 // Initializing the root op with a depth of 0
86 opDepths[rootOp] = 0;
87 options.filter = [&](Operation *subOp) {
88 // If the subOp hasn't been recorded in opDepths, it is deeper than
89 // maxDepth.
90 if (!opDepths.contains(subOp))
91 return false;
92 // Examine subOp's operands to compute depths of their defining operations.
93 for (auto operand : subOp->getOperands()) {
94 int64_t newDepth = opDepths[subOp] + 1;
95 // If the newDepth is greater than maxDepth, further computation can be
96 // skipped.
97 if (newDepth > maxDepth)
98 continue;
99
100 if (auto definingOp = operand.getDefiningOp()) {
101 // Registers the minimum depth
102 if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
103 opDepths[definingOp] = newDepth;
104 } else {
105 auto blockArgument = cast<BlockArgument>(operand);
106 Operation *parentOp = blockArgument.getOwner()->getParentOp();
107 if (!parentOp)
108 continue;
109
110 if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
111 opDepths[parentOp] = newDepth;
112 }
113 }
114 return true;
115 };
116 LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
117 assert(result.succeeded() && "expected backward slice to succeed");
118 (void)result;
119 return options.inclusive ? backwardSlice.size() > 1
120 : backwardSlice.size() >= 1;
121}
122
123/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
124/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
125template <typename BaseMatcher, typename Filter>
127public:
128 PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
129 bool inclusive, bool omitBlockArguments,
130 bool omitUsesFromAbove)
131 : innerMatcher(std::move(innerMatcher)),
132 filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
133 omitBlockArguments(omitBlockArguments),
134 omitUsesFromAbove(omitUsesFromAbove) {}
135
136 bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
137 backwardSlice.clear();
139 options.inclusive = inclusive;
140 options.omitUsesFromAbove = omitUsesFromAbove;
141 options.omitBlockArguments = omitBlockArguments;
142 if (innerMatcher.match(rootOp)) {
143 options.filter = [&](Operation *subOp) {
144 return !filterMatcher.match(subOp);
145 };
146 LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
147 assert(result.succeeded() && "expected backward slice to succeed");
148 (void)result;
149 return options.inclusive ? backwardSlice.size() > 1
150 : backwardSlice.size() >= 1;
151 }
152 return false;
153 }
154
155private:
156 BaseMatcher innerMatcher;
157 Filter filterMatcher;
158 bool inclusive;
159 bool omitBlockArguments;
160 bool omitUsesFromAbove;
161};
162
163/// Computes the forward-slice of all users reachable from `rootOp`,
164/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
165template <typename BaseMatcher, typename Filter>
167public:
168 PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
169 bool inclusive)
170 : innerMatcher(std::move(innerMatcher)),
171 filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
172
173 bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
174 forwardSlice.clear();
176 options.inclusive = inclusive;
177 if (innerMatcher.match(rootOp)) {
178 options.filter = [&](Operation *subOp) {
179 return !filterMatcher.match(subOp);
180 };
181 getForwardSlice(rootOp, &forwardSlice, options);
182 return options.inclusive ? forwardSlice.size() > 1
183 : forwardSlice.size() >= 1;
184 }
185 return false;
186 }
187
188private:
189 BaseMatcher innerMatcher;
190 Filter filterMatcher;
191 bool inclusive;
192};
193
194/// Matches transitive defs of a top-level operation up to N levels.
195template <typename Matcher>
196inline BackwardSliceMatcher<Matcher>
197m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
198 bool omitBlockArguments, bool omitUsesFromAbove) {
199 assert(maxDepth >= 0 && "maxDepth must be non-negative");
200 return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
201 inclusive, omitBlockArguments,
202 omitUsesFromAbove);
203}
204
205/// Matches all transitive defs of a top-level operation up to N levels.
206template <typename Matcher>
208 int64_t maxDepth) {
209 assert(maxDepth >= 0 && "maxDepth must be non-negative");
210 return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true,
211 false, false);
212}
213
214/// Matches all transitive defs of a top-level operation and stops where
215/// `filterMatcher` rejects.
216template <typename BaseMatcher, typename Filter>
217inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
218m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
219 bool inclusive, bool omitBlockArguments,
220 bool omitUsesFromAbove) {
222 std::move(innerMatcher), std::move(filterMatcher), inclusive,
223 omitBlockArguments, omitUsesFromAbove);
224}
225
226/// Matches all users of a top-level operation and stops where
227/// `filterMatcher` rejects.
228template <typename BaseMatcher, typename Filter>
229inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
230m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
231 bool inclusive) {
233 std::move(innerMatcher), std::move(filterMatcher), inclusive);
234}
235
236} // namespace mlir::query::matcher
237
238#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
bool match(Operation *rootOp, SetVector< Operation * > &backwardSlice)
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
bool match(Operation *rootOp, SetVector< Operation * > &backwardSlice)
Computes the forward-slice of all users reachable from rootOp, if innerMatcher matches.
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive)
bool match(Operation *rootOp, SetVector< Operation * > &forwardSlice)
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
BackwardSliceMatcher< Matcher > m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
Matches transitive defs of a top-level operation up to N levels.
PredicateForwardSliceMatcher< BaseMatcher, Filter > m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive)
Matches all users of a top-level operation and stops where filterMatcher rejects.
BackwardSliceMatcher< Matcher > m_GetAllDefinitions(Matcher innerMatcher, int64_t maxDepth)
Matches all transitive defs of a top-level operation up to N levels.
PredicateBackwardSliceMatcher< BaseMatcher, Filter > m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, bool inclusive, bool omitBlockArguments, bool omitUsesFromAbove)
Matches all transitive defs of a top-level operation and stops where filterMatcher rejects.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
SliceOptions ForwardSliceOptions
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.