MLIR  14.0.0git
Matchers.h
Go to the documentation of this file.
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 
21 namespace mlir {
22 
23 namespace detail {
24 
25 /// The matcher that matches a certain kind of Attribute and binds the value
26 /// inside the Attribute.
27 template <
28  typename AttrClass,
29  // Require AttrClass to be a derived class from Attribute and get its
30  // value type
31  typename ValueType =
33  AttrClass>::type::ValueType,
34  // Require the ValueType is not void
35  typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
37  ValueType *bind_value;
38 
39  /// Creates a matcher instance that binds the value to bv if match succeeds.
40  attr_value_binder(ValueType *bv) : bind_value(bv) {}
41 
42  bool match(const Attribute &attr) {
43  if (auto intAttr = attr.dyn_cast<AttrClass>()) {
44  *bind_value = intAttr.getValue();
45  return true;
46  }
47  return false;
48  }
49 };
50 
51 /// Check to see if the specified operation is ConstantLike. This includes some
52 /// quick filters to avoid a semi-expensive test in the common case.
53 static bool isConstantLike(Operation *op) {
54  return op->getNumOperands() == 0 && op->getNumResults() == 1 &&
56 }
57 
58 /// The matcher that matches operations that have the `ConstantLike` trait.
60  bool match(Operation *op) { return isConstantLike(op); }
61 };
62 
63 /// The matcher that matches operations that have the `ConstantLike` trait, and
64 /// binds the folded attribute value.
65 template <typename AttrT>
67  AttrT *bind_value;
68 
69  /// Creates a matcher instance that binds the constant attribute value to
70  /// bind_value if match succeeds.
71  constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
72  /// Creates a matcher instance that doesn't bind if match succeeds.
73  constant_op_binder() : bind_value(nullptr) {}
74 
75  bool match(Operation *op) {
76  if (!isConstantLike(op))
77  return false;
78 
79  // Fold the constant to an attribute.
81  LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
82  (void)result;
83  assert(succeeded(result) && "expected ConstantLike op to be foldable");
84 
85  if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
86  if (bind_value)
87  *bind_value = attr;
88  return true;
89  }
90  return false;
91  }
92 };
93 
94 /// The matcher that matches a constant scalar / vector splat / tensor splat
95 /// integer operation and binds the constant integer value.
97  IntegerAttr::ValueType *bind_value;
98 
99  /// Creates a matcher instance that binds the value to bv if match succeeds.
100  constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
101 
102  bool match(Operation *op) {
103  Attribute attr;
104  if (!constant_op_binder<Attribute>(&attr).match(op))
105  return false;
106  auto type = op->getResult(0).getType();
107 
108  if (type.isa<IntegerType, IndexType>())
110  if (type.isa<VectorType, RankedTensorType>()) {
111  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
113  .match(splatAttr.getSplatValue<Attribute>());
114  }
115  }
116  return false;
117  }
118 };
119 
120 /// The matcher that matches a given target constant scalar / vector splat /
121 /// tensor splat integer value.
122 template <int64_t TargetValue>
124  bool match(Operation *op) {
125  APInt value;
126  return constant_int_op_binder(&value).match(op) && TargetValue == value;
127  }
128 };
129 
130 /// The matcher that matches anything except the given target constant scalar /
131 /// vector splat / tensor splat integer value.
132 template <int64_t TargetNotValue>
134  bool match(Operation *op) {
135  APInt value;
136  return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
137  }
138 };
139 
140 /// The matcher that matches a certain kind of op.
141 template <typename OpClass>
142 struct op_matcher {
143  bool match(Operation *op) { return isa<OpClass>(op); }
144 };
145 
/// Trait to check whether T provides a 'match' method with type
/// `OperationOrValue`.
///
/// The alias is well-formed only when `T::match` is callable with an
/// `OperationOrValue` argument.
template <typename T, typename OperationOrValue>
using has_operation_or_value_matcher_t =
    decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
151 
152 /// Statically switch to a Value matcher.
153 template <typename MatcherClass>
154 typename std::enable_if_t<
155  llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
156  Value>::value,
157  bool>
158 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
159  return matcher.match(op->getOperand(idx));
160 }
161 
162 /// Statically switch to an Operation matcher.
163 template <typename MatcherClass>
164 typename std::enable_if_t<
165  llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
166  Operation *>::value,
167  bool>
168 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
169  if (auto *defOp = op->getOperand(idx).getDefiningOp())
170  return matcher.match(defOp);
171  return false;
172 }
173 
174 /// Terminal matcher, always returns true.
176  bool match(Value op) const { return true; }
177 };
178 
179 /// Terminal matcher, always returns true.
182  AnyCapturedValueMatcher(Value *what) : what(what) {}
183  bool match(Value op) const {
184  *what = op;
185  return true;
186  }
187 };
188 
189 /// Binds to a specific value and matches it.
192  bool match(Value val) const { return val == value; }
194 };
195 
/// Invokes `callback(std::integral_constant<std::size_t, I>{},
/// std::get<I>(tuple))` once for every index `I` in `Is`, in pack order.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // A braced-init-list evaluates its elements left to right, which sequences
  // the calls. The leading 0 keeps the list non-empty for an empty pack, and
  // the void cast stops a callback result type with an overloaded comma
  // operator from changing the element type.
  (void)std::initializer_list<int>{
      0,
      (static_cast<void>(callback(std::integral_constant<std::size_t, Is>{},
                                  std::get<Is>(tuple))),
       0)...};
}
204 
205 template <typename... Tys, typename CallbackT>
206 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
207  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
208  std::make_index_sequence<sizeof...(Tys)>{});
209 }
210 
211 /// RecursivePatternMatcher that composes.
212 template <typename OpType, typename... OperandMatchers>
214  RecursivePatternMatcher(OperandMatchers... matchers)
215  : operandMatchers(matchers...) {}
216  bool match(Operation *op) {
217  if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
218  return false;
219  bool res = true;
220  enumerate(operandMatchers, [&](size_t index, auto &matcher) {
221  res &= matchOperandOrValueAtIndex(op, index, matcher);
222  });
223  return res;
224  }
225  std::tuple<OperandMatchers...> operandMatchers;
226 };
227 
228 } // namespace detail
229 
230 /// Matches a constant foldable operation.
233 }
234 
235 /// Matches a value from a constant foldable operation and writes the value to
236 /// bind_value.
237 template <typename AttrT>
240 }
241 
242 /// Matches a constant scalar / vector splat / tensor splat integer one.
245 }
246 
247 /// Matches the given OpClass.
248 template <typename OpClass>
251 }
252 
253 /// Matches a constant scalar / vector splat / tensor splat integer zero.
256 }
257 
258 /// Matches a constant scalar / vector splat / tensor splat integer that is any
259 /// non-zero value.
262 }
263 
264 /// Entry point for matching a pattern over a Value.
265 template <typename Pattern>
266 inline bool matchPattern(Value value, const Pattern &pattern) {
267  // TODO: handle other cases
268  if (auto *op = value.getDefiningOp())
269  return const_cast<Pattern &>(pattern).match(op);
270  return false;
271 }
272 
273 /// Entry point for matching a pattern over an Operation.
274 template <typename Pattern>
275 inline bool matchPattern(Operation *op, const Pattern &pattern) {
276  return const_cast<Pattern &>(pattern).match(op);
277 }
278 
279 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
280 /// writes the integer value to bind_value.
282 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
283  return detail::constant_int_op_binder(bind_value);
284 }
285 
286 template <typename OpType, typename... Matchers>
287 auto m_Op(Matchers... matchers) {
288  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
289 }
290 
291 namespace matchers {
292 inline auto m_Any() { return detail::AnyValueMatcher(); }
293 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
294 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
295 } // namespace matchers
296 
297 } // namespace mlir
298 
299 #endif // MLIR_IR_MATCHERS_H
Include the generated interface declarations.
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
Definition: Matchers.h:197
The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute ...
Definition: Matchers.h:66
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:282
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Binds to a specific value and matches it.
Definition: Matchers.h:190
detail::constant_int_not_value_matcher< 0 > m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value...
Definition: Matchers.h:260
Value getOperand(unsigned idx)
Definition: Operation.h:219
The matcher that matches a given target constant scalar / vector splat / tensor splat integer value...
Definition: Matchers.h:123
static bool isConstantLike(Operation *op)
Check to see if the specified operation is ConstantLike.
Definition: Matchers.h:53
unsigned getNumOperands()
Definition: Operation.h:215
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:40
std::enable_if_t< llvm::is_detected< detail::has_operation_or_value_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
Definition: Matchers.h:158
constant_int_op_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:100
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
detail::constant_int_value_matcher< 1 > m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:243
auto m_Any(Value *val)
Definition: Matchers.h:293
The matcher that matches anything except the given target constant scalar / vector splat / tensor spl...
Definition: Matchers.h:133
static constexpr const bool value
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
The matcher that matches operations that have the ConstantLike trait.
Definition: Matchers.h:59
RecursivePatternMatcher that composes.
Definition: Matchers.h:213
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::tuple< OperandMatchers... > operandMatchers
Definition: Matchers.h:225
bool match(Value op) const
Definition: Matchers.h:176
bool match(Operation *op)
Definition: Matchers.h:75
Attributes are known-constant values of operations.
Definition: Attributes.h:24
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
constant_op_binder()
Creates a matcher instance that doesn't bind if match succeeds.
Definition: Matchers.h:73
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:231
IntegerAttr::ValueType * bind_value
Definition: Matchers.h:97
decltype(std::declval< T >().match(std::declval< OperationOrValue >())) has_operation_or_value_matcher_t
Trait to check whether T provides a 'match' method with type OperationOrValue.
Definition: Matchers.h:150
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition: Matchers.h:249
The matcher that matches a certain kind of op.
Definition: Matchers.h:142
Terminal matcher, always returns true.
Definition: Matchers.h:175
Type getType() const
Return the type of this value.
Definition: Value.h:117
RecursivePatternMatcher(OperandMatchers... matchers)
Definition: Matchers.h:214
The matcher that matches a constant scalar / vector splat / tensor splat integer operation and binds ...
Definition: Matchers.h:96
U dyn_cast() const
Definition: Attributes.h:117
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
bool match(Operation *op)
Definition: Matchers.h:143
constant_op_binder(AttrT *bind_value)
Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds...
Definition: Matchers.h:71
bool match(Value val) const
Definition: Matchers.h:192
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
auto m_Val(Value v)
Definition: Matchers.h:294
detail::constant_int_value_matcher< 0 > m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:254
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:496
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
bool match(Operation *op)
Definition: Matchers.h:60
bool match(const Attribute &attr)
Definition: Matchers.h:42
This class provides the API for a sub-set of ops that are known to be constant-like.
The matcher that matches a certain kind of Attribute and binds the value inside the Attribute...
Definition: Matchers.h:36
Terminal matcher, always returns true.
Definition: Matchers.h:180