MLIR 22.0.0git
MatchersInternal.h
Go to the documentation of this file.
1//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements the base layer of the matcher framework.
10//
11// Matchers are methods that return a Matcher which provides a
12// `match(...)` method whose parameters define the context of the match.
13// Support includes simple (unary) matchers as well as matcher combinators
14// (anyOf, allOf, etc.)
15//
16// This file contains the wrapper classes needed to construct matchers for
17// mlir-query.
18//
19//===----------------------------------------------------------------------===//
20
21#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
22#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
23
24#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/IntrusiveRefCntPtr.h"
26
27namespace mlir::query::matcher {
28class DynMatcher;
29namespace internal {
30
31bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
32 ArrayRef<DynMatcher> innerMatchers);
33bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
34 ArrayRef<DynMatcher> innerMatchers);
35
36} // namespace internal
37
38// Defaults to false if T has no match() method with the signature:
39// match(Operation* op).
40template <typename T, typename = void>
41struct has_simple_match : std::false_type {};
42
43// Specialized type trait that evaluates to true if T has a match() method
44// with the signature: match(Operation* op).
45template <typename T>
46struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
47 std::declval<Operation *>()))>>
48 : std::true_type {};
49
50// Defaults to false if T has no match() method with the signature:
51// match(Operation* op, SetVector<Operation*>&).
52template <typename T, typename = void>
53struct has_bound_match : std::false_type {};
54
55// Specialized type trait that evaluates to true if T has a match() method
56// with the signature: match(Operation* op, SetVector<Operation*>&).
57template <typename T>
58struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
59 std::declval<Operation *>(),
60 std::declval<SetVector<Operation *> &>()))>>
61 : std::true_type {};
62
63// Generic interface for matchers on an MLIR operation.
65 : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
66public:
67 virtual ~MatcherInterface() = default;
68
69 virtual bool match(Operation *op) = 0;
70 virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
71};
72
73// MatcherFnImpl takes a matcher function object and implements
74// MatcherInterface.
75template <typename MatcherFn>
77public:
78 MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
79
80 bool match(Operation *op) override {
82 return matcherFn.match(op);
83 return false;
84 }
85
86 bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
88 return matcherFn.match(op, matchedOps);
89 return false;
90 }
91
92private:
93 MatcherFn matcherFn;
94};
95
96// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
97// match the given operation.
99 SetVector<Operation *> *matchedOps,
100 ArrayRef<DynMatcher> innerMatchers);
101
102template <VariadicOperatorFunction Func>
104public:
105 VariadicMatcher(std::vector<DynMatcher> matchers)
106 : matchers(std::move(matchers)) {}
107
108 bool match(Operation *op) override { return Func(op, nullptr, matchers); }
109 bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
110 return Func(op, &matchedOps, matchers);
111 }
112
113private:
114 std::vector<DynMatcher> matchers;
115};
116
117// Matcher wraps a MatcherInterface implementation and provides match()
118// methods that redirect calls to the underlying implementation.
120public:
121 // Takes ownership of the provided implementation pointer.
123 : implementation(implementation) {}
124
125 // Construct from a variadic function.
127 // Matches operations for which all provided matchers match.
129 // Matches operations for which at least one of the provided matchers
130 // matches.
132 };
133
134 static std::unique_ptr<DynMatcher>
136 std::vector<DynMatcher> innerMatchers) {
137 switch (Op) {
138 case AllOf:
139 return std::make_unique<DynMatcher>(
141 std::move(innerMatchers)));
142 case AnyOf:
143 return std::make_unique<DynMatcher>(
145 std::move(innerMatchers)));
146 }
147 llvm_unreachable("Invalid Op value.");
148 }
149
150 template <typename MatcherFn>
151 static std::unique_ptr<DynMatcher>
152 constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
153 auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
154 return std::make_unique<DynMatcher>(impl.release());
155 }
156
157 bool match(Operation *op) const { return implementation->match(op); }
158 bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
159 return implementation->match(op, matchedOps);
160 }
161
162 void setFunctionName(StringRef name) { functionName = name.str(); }
163 bool hasFunctionName() const { return !functionName.empty(); }
164 StringRef getFunctionName() const { return functionName; }
165
166private:
167 llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
168 std::string functionName;
169};
170
171// VariadicOperatorMatcher related types.
172template <typename... Ps>
174public:
176 : varOp(varOp), params(std::forward<Ps>(params)...) {}
177
178 operator std::unique_ptr<DynMatcher>() const & {
180 varOp, getMatchers(std::index_sequence_for<Ps...>()));
181 }
182
183 operator std::unique_ptr<DynMatcher>() && {
185 varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
186 }
187
188private:
189 // Helper method to unpack the tuple into a vector.
190 template <std::size_t... Is>
191 std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
192 return {DynMatcher(std::get<Is>(params))...};
193 }
194
195 template <std::size_t... Is>
196 std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
197 return {DynMatcher(std::get<Is>(std::move(params)))...};
198 }
199
201 std::tuple<Ps...> params;
202};
203
204// Overloaded function object to generate VariadicOperatorMatcher objects from
205// arbitrary matchers.
206template <unsigned MinCount, unsigned MaxCount>
209
210 template <typename... Ms>
211 VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
212 static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
213 "invalid number of parameters for variadic matcher");
214 return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
215 }
216};
217
218namespace internal {
219const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
223} // namespace internal
224} // namespace mlir::query::matcher
225
226#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DynMatcher(MatcherInterface *implementation)
bool match(Operation *op) const
static std::unique_ptr< DynMatcher > constructVariadic(VariadicOperator Op, std::vector< DynMatcher > innerMatchers)
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
bool match(Operation *op, SetVector< Operation * > &matchedOps) const
bool match(Operation *op, SetVector< Operation * > &matchedOps) override
bool match(Operation *op) override
virtual bool match(Operation *op, SetVector< Operation * > &matchedOps)=0
virtual bool match(Operation *op)=0
bool match(Operation *op, SetVector< Operation * > &matchedOps) override
bool match(Operation *op) override
VariadicMatcher(std::vector< DynMatcher > matchers)
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
bool anyOfVariadicOperator(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers)
const VariadicOperatorMatcherFunc< 1, std::numeric_limits< unsigned >::max()> anyOf
bool allOfVariadicOperator(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers)
const VariadicOperatorMatcherFunc< 1, std::numeric_limits< unsigned >::max()> allOf
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
bool(*)(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers) VariadicOperatorFunction
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
VariadicOperatorMatcher< Ms... > operator()(Ms &&...Ps) const