MLIR 23.0.0git
MatchersInternal.h
Go to the documentation of this file.
1//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements the base layer of the matcher framework.
10//
11// Matchers are methods that return a Matcher which provides a
12// `match(...)` method whose parameters define the context of the match.
13// Support includes simple (unary) matchers as well as matcher combinators
14// (anyOf, allOf, etc.)
15//
16// This file contains the wrapper classes needed to construct matchers for
17// mlir-query.
18//
19//===----------------------------------------------------------------------===//
20
21#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
22#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
23
24#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/IntrusiveRefCntPtr.h"
26
27namespace mlir::query::matcher {
28class DynMatcher;
29namespace internal {
30
31bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
32 ArrayRef<DynMatcher> innerMatchers);
33bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
34 ArrayRef<DynMatcher> innerMatchers);
35
36} // namespace internal
37
38// Defaults to false if T has no match() method with the signature:
39// match(Operation* op).
40template <typename T, typename = void>
41struct has_simple_match : std::false_type {};
42
43// Specialized type trait that evaluates to true if T has a match() method
44// with the signature: match(Operation* op).
45template <typename T>
46struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
47 std::declval<Operation *>()))>>
48 : std::true_type {};
49
50// Defaults to false if T has no match() method with the signature:
51// match(Operation* op, SetVector<Operation*>&).
52template <typename T, typename = void>
53struct has_bound_match : std::false_type {};
54
55// Specialized type trait that evaluates to true if T has a match() method
56// with the signature: match(Operation* op, SetVector<Operation*>&).
57template <typename T>
58struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
59 std::declval<Operation *>(),
60 std::declval<SetVector<Operation *> &>()))>>
61 : std::true_type {};
62
63// Generic interface for matchers on an MLIR operation.
65 : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
66public:
67 virtual ~MatcherInterface() = default;
68
69 virtual bool match(Operation *op) = 0;
70 virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
71};
72
73// MatcherFnImpl takes a matcher function object and implements
74// MatcherInterface.
75template <typename MatcherFn>
77public:
78 MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
79
80 bool match(Operation *op) override {
82 return matcherFn.match(op);
83 return false;
84 }
85
86 bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
88 return matcherFn.match(op, matchedOps);
89 return false;
90 }
91
92private:
93 MatcherFn matcherFn;
94};
95
96// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
97// match the given operation.
99 SetVector<Operation *> *matchedOps,
100 ArrayRef<DynMatcher> innerMatchers);
101
102template <VariadicOperatorFunction Func>
104public:
105 VariadicMatcher(std::vector<DynMatcher> matchers);
107
108 bool match(Operation *op) override;
109 bool match(Operation *op, SetVector<Operation *> &matchedOps) override;
110
111private:
112 std::vector<DynMatcher> matchers;
113};
114
115// Matcher wraps a MatcherInterface implementation and provides match()
116// methods that redirect calls to the underlying implementation.
118public:
119 // Takes ownership of the provided implementation pointer.
121 : implementation(implementation) {}
122
123 // Construct from a variadic function.
125 // Matches operations for which all provided matchers match.
127 // Matches operations for which at least one of the provided matchers
128 // matches.
130 };
131
132 static std::unique_ptr<DynMatcher>
134 std::vector<DynMatcher> innerMatchers) {
135 switch (Op) {
136 case AllOf:
137 return std::make_unique<DynMatcher>(
139 std::move(innerMatchers)));
140 case AnyOf:
141 return std::make_unique<DynMatcher>(
143 std::move(innerMatchers)));
144 }
145 llvm_unreachable("Invalid Op value.");
146 }
147
148 template <typename MatcherFn>
149 static std::unique_ptr<DynMatcher>
150 constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
151 auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
152 return std::make_unique<DynMatcher>(impl.release());
153 }
154
155 bool match(Operation *op) const { return implementation->match(op); }
156 bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
157 return implementation->match(op, matchedOps);
158 }
159
160 void setFunctionName(StringRef name) { functionName = name.str(); }
161 bool hasFunctionName() const { return !functionName.empty(); }
162 StringRef getFunctionName() const { return functionName; }
163
164private:
165 llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
166 std::string functionName;
167};
168
169// Implementation of VariadicMatcher functions after DynMatcher is defined
170template <VariadicOperatorFunction Func>
171VariadicMatcher<Func>::VariadicMatcher(std::vector<DynMatcher> matchers)
172 : matchers(std::move(matchers)) {}
173
174template <VariadicOperatorFunction Func>
176
177template <VariadicOperatorFunction Func>
179 return Func(op, nullptr, matchers);
180}
181
182template <VariadicOperatorFunction Func>
184 SetVector<Operation *> &matchedOps) {
185 return Func(op, &matchedOps, matchers);
186}
187
188// VariadicOperatorMatcher related types.
189template <typename... Ps>
191public:
193 : varOp(varOp), params(std::forward<Ps>(params)...) {}
194
195 operator std::unique_ptr<DynMatcher>() const & {
197 varOp, getMatchers(std::index_sequence_for<Ps...>()));
198 }
199
200 operator std::unique_ptr<DynMatcher>() && {
202 varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
203 }
204
205private:
206 // Helper method to unpack the tuple into a vector.
207 template <std::size_t... Is>
208 std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
209 return {DynMatcher(std::get<Is>(params))...};
210 }
211
212 template <std::size_t... Is>
213 std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
214 return {DynMatcher(std::get<Is>(std::move(params)))...};
215 }
216
218 std::tuple<Ps...> params;
219};
220
221// Overloaded function object to generate VariadicOperatorMatcher objects from
222// arbitrary matchers.
223template <unsigned MinCount, unsigned MaxCount>
226
227 template <typename... Ms>
228 VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
229 static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
230 "invalid number of parameters for variadic matcher");
231 return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
232 }
233};
234
235namespace internal {
236const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
240} // namespace internal
241} // namespace mlir::query::matcher
242
243#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
DynMatcher(MatcherInterface *implementation)
bool match(Operation *op) const
static std::unique_ptr< DynMatcher > constructVariadic(VariadicOperator Op, std::vector< DynMatcher > innerMatchers)
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
bool match(Operation *op, SetVector< Operation * > &matchedOps) const
bool match(Operation *op, SetVector< Operation * > &matchedOps) override
bool match(Operation *op) override
virtual bool match(Operation *op, SetVector< Operation * > &matchedOps)=0
virtual bool match(Operation *op)=0
bool match(Operation *op) override
VariadicMatcher(std::vector< DynMatcher > matchers)
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
bool anyOfVariadicOperator(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers)
const VariadicOperatorMatcherFunc< 1, std::numeric_limits< unsigned >::max()> anyOf
bool allOfVariadicOperator(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers)
const VariadicOperatorMatcherFunc< 1, std::numeric_limits< unsigned >::max()> allOf
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
bool(*)(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers) VariadicOperatorFunction
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
VariadicOperatorMatcher< Ms... > operator()(Ms &&...Ps) const