MLIR  21.0.0git
MatchersInternal.h
Go to the documentation of this file.
1 //===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements the base layer of the matcher framework.
10 //
11 // Matchers are methods that return a Matcher which provides a
12 // `match(...)` method whose parameters define the context of the match.
13 // Support includes simple (unary) matchers as well as matcher combinators
14 // (anyOf, allOf, etc.)
15 //
16 // This file contains the wrapper classes needed to construct matchers for
17 // mlir-query.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
22 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
23 
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include <limits>
26 
27 namespace mlir::query::matcher {
28 class DynMatcher;
29 namespace internal {
30 
31 bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
32  ArrayRef<DynMatcher> innerMatchers);
33 bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
34  ArrayRef<DynMatcher> innerMatchers);
35 
36 } // namespace internal
37 
38 // Defaults to false if T has no match() method with the signature:
39 // match(Operation* op).
40 template <typename T, typename = void>
41 struct has_simple_match : std::false_type {};
42 
43 // Specialized type trait that evaluates to true if T has a match() method
44 // with the signature: match(Operation* op).
45 template <typename T>
46 struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
47  std::declval<Operation *>()))>>
48  : std::true_type {};
49 
50 // Defaults to false if T has no match() method with the signature:
51 // match(Operation* op, SetVector<Operation*>&).
52 template <typename T, typename = void>
53 struct has_bound_match : std::false_type {};
54 
55 // Specialized type trait that evaluates to true if T has a match() method
56 // with the signature: match(Operation* op, SetVector<Operation*>&).
57 template <typename T>
58 struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
59  std::declval<Operation *>(),
60  std::declval<SetVector<Operation *> &>()))>>
61  : std::true_type {};
62 
63 // Generic interface for matchers on an MLIR operation.
65  : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
66 public:
67  virtual ~MatcherInterface() = default;
68 
69  virtual bool match(Operation *op) = 0;
70  virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
71 };
72 
73 // MatcherFnImpl takes a matcher function object and implements
74 // MatcherInterface.
75 template <typename MatcherFn>
77 public:
78  MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
79 
80  bool match(Operation *op) override {
82  return matcherFn.match(op);
83  return false;
84  }
85 
86  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
88  return matcherFn.match(op, matchedOps);
89  return false;
90  }
91 
92 private:
93  MatcherFn matcherFn;
94 };
95 
96 // VariadicMatcher takes a vector of Matchers and returns true if any Matchers
97 // match the given operation.
98 using VariadicOperatorFunction = bool (*)(Operation *op,
99  SetVector<Operation *> *matchedOps,
100  ArrayRef<DynMatcher> innerMatchers);
101 
102 template <VariadicOperatorFunction Func>
104 public:
105  VariadicMatcher(std::vector<DynMatcher> matchers)
106  : matchers(std::move(matchers)) {}
107 
108  bool match(Operation *op) override { return Func(op, nullptr, matchers); }
109  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
110  return Func(op, &matchedOps, matchers);
111  }
112 
113 private:
114  std::vector<DynMatcher> matchers;
115 };
116 
117 // Matcher wraps a MatcherInterface implementation and provides match()
118 // methods that redirect calls to the underlying implementation.
119 class DynMatcher {
120 public:
121  // Takes ownership of the provided implementation pointer.
122  DynMatcher(MatcherInterface *implementation)
123  : implementation(implementation) {}
124 
125  // Construct from a variadic function.
127  // Matches operations for which all provided matchers match.
129  // Matches operations for which at least one of the provided matchers
130  // matches.
131  AnyOf
132  };
133 
134  static std::unique_ptr<DynMatcher>
136  std::vector<DynMatcher> innerMatchers) {
137  switch (Op) {
138  case AllOf:
139  return std::make_unique<DynMatcher>(
141  std::move(innerMatchers)));
142  case AnyOf:
143  return std::make_unique<DynMatcher>(
145  std::move(innerMatchers)));
146  }
147  llvm_unreachable("Invalid Op value.");
148  }
149 
150  template <typename MatcherFn>
151  static std::unique_ptr<DynMatcher>
152  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
153  auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
154  return std::make_unique<DynMatcher>(impl.release());
155  }
156 
157  bool match(Operation *op) const { return implementation->match(op); }
158  bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
159  return implementation->match(op, matchedOps);
160  }
161 
162  void setFunctionName(StringRef name) { functionName = name.str(); }
163  bool hasFunctionName() const { return !functionName.empty(); }
164  StringRef getFunctionName() const { return functionName; }
165 
166 private:
167  llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
168  std::string functionName;
169 };
170 
171 // VariadicOperatorMatcher related types.
172 template <typename... Ps>
174 public:
176  : varOp(varOp), params(std::forward<Ps>(params)...) {}
177 
178  operator std::unique_ptr<DynMatcher>() const & {
180  varOp, getMatchers(std::index_sequence_for<Ps...>()));
181  }
182 
183  operator std::unique_ptr<DynMatcher>() && {
185  varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
186  }
187 
188 private:
189  // Helper method to unpack the tuple into a vector.
190  template <std::size_t... Is>
191  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
192  return {DynMatcher(std::get<Is>(params))...};
193  }
194 
195  template <std::size_t... Is>
196  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
197  return {DynMatcher(std::get<Is>(std::move(params)))...};
198  }
199 
200  const DynMatcher::VariadicOperator varOp;
201  std::tuple<Ps...> params;
202 };
203 
204 // Overloaded function object to generate VariadicOperatorMatcher objects from
205 // arbitrary matchers.
206 template <unsigned MinCount, unsigned MaxCount>
209 
210  template <typename... Ms>
211  VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
212  static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
213  "invalid number of parameters for variadic matcher");
214  return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
215  }
216 };
217 
218 namespace internal {
223 } // namespace internal
224 } // namespace mlir::query::matcher
225 
226 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DynMatcher(MatcherInterface *implementation)
bool match(Operation *op) const
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
static std::unique_ptr< DynMatcher > constructVariadic(VariadicOperator Op, std::vector< DynMatcher > innerMatchers)
bool match(Operation *op, SetVector< Operation * > &matchedOps) const
void setFunctionName(StringRef name)
bool match(Operation *op, SetVector< Operation * > &matchedOps) override
bool match(Operation *op) override
virtual bool match(Operation *op, SetVector< Operation * > &matchedOps)=0
virtual bool match(Operation *op)=0
bool match(Operation *op, SetVector< Operation * > &matchedOps) override
bool match(Operation *op) override
VariadicMatcher(std::vector< DynMatcher > matchers)
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
bool anyOfVariadicOperator(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers)
const VariadicOperatorMatcherFunc< 1, std::numeric_limits< unsigned >::max()> anyOf
bool allOfVariadicOperator(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers)
const VariadicOperatorMatcherFunc< 1, std::numeric_limits< unsigned >::max()> allOf
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
Definition: ErrorBuilder.h:20
bool(*)(Operation *op, SetVector< Operation * > *matchedOps, ArrayRef< DynMatcher > innerMatchers) VariadicOperatorFunction
VariadicOperatorMatcher< Ms... > operator()(Ms &&...Ps) const