MLIR  19.0.0git
Matchers.h
Go to the documentation of this file.
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/OpDefinition.h"
21 
22 namespace mlir {
23 
24 namespace detail {
25 
26 /// The matcher that matches a certain kind of Attribute and binds the value
27 /// inside the Attribute.
28 template <
29  typename AttrClass,
30  // Require AttrClass to be a derived class from Attribute and get its
31  // value type
32  typename ValueType = typename std::enable_if_t<
33  std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
34  // Require the ValueType is not void
35  typename = std::enable_if_t<!std::is_void<ValueType>::value>>
37  ValueType *bind_value;
38 
39  /// Creates a matcher instance that binds the value to bv if match succeeds.
40  attr_value_binder(ValueType *bv) : bind_value(bv) {}
41 
42  bool match(Attribute attr) {
43  if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) {
44  *bind_value = intAttr.getValue();
45  return true;
46  }
47  return false;
48  }
49 };
50 
51 /// The matcher that matches operations that have the `ConstantLike` trait.
53  bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
54 };
55 
56 /// The matcher that matches operations that have the specified op name.
57 struct NameOpMatcher {
58  NameOpMatcher(StringRef name) : name(name) {}
59  bool match(Operation *op) { return op->getName().getStringRef() == name; }
60 
61  StringRef name;
62 };
63 
64 /// The matcher that matches operations that have the specified attribute name.
65 struct AttrOpMatcher {
67  bool match(Operation *op) { return op->hasAttr(attrName); }
68 
69  StringRef attrName;
70 };
71 
72 /// The matcher that matches operations that have the `ConstantLike` trait, and
73 /// binds the folded attribute value.
74 template <typename AttrT>
76  AttrT *bind_value;
77 
78  /// Creates a matcher instance that binds the constant attribute value to
79  /// bind_value if match succeeds.
81  /// Creates a matcher instance that doesn't bind if match succeeds.
83 
84  bool match(Operation *op) {
85  if (!op->hasTrait<OpTrait::ConstantLike>())
86  return false;
87 
88  // Fold the constant to an attribute.
90  LogicalResult result = op->fold(/*operands=*/std::nullopt, foldedOp);
91  (void)result;
92  assert(succeeded(result) && "expected ConstantLike op to be foldable");
93 
94  if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
95  if (bind_value)
96  *bind_value = attr;
97  return true;
98  }
99  return false;
100  }
101 };
102 
103 /// The matcher that matches operations that have the specified attribute
104 /// name, and binds the attribute value.
105 template <typename AttrT>
106 struct AttrOpBinder {
107  /// Creates a matcher instance that binds the attribute value to
108  /// bind_value if match succeeds.
109  AttrOpBinder(StringRef attrName, AttrT *bindValue)
111  /// Creates a matcher instance that doesn't bind if match succeeds.
112  AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {}
113 
114  bool match(Operation *op) {
115  if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
116  if (bindValue)
117  *bindValue = attr;
118  return true;
119  }
120  return false;
121  }
122  StringRef attrName;
123  AttrT *bindValue;
124 };
125 
126 /// The matcher that matches a constant scalar / vector splat / tensor splat
127 /// float Attribute or Operation and binds the constant float value.
129  FloatAttr::ValueType *bind_value;
130 
131  /// Creates a matcher instance that binds the value to bv if match succeeds.
132  constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
133 
134  bool match(Attribute attr) {
136  if (matcher.match(attr))
137  return true;
138 
139  if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
140  return matcher.match(splatAttr.getSplatValue<Attribute>());
141 
142  return false;
143  }
144 
145  bool match(Operation *op) {
146  Attribute attr;
147  if (!constant_op_binder<Attribute>(&attr).match(op))
148  return false;
149 
150  Type type = op->getResult(0).getType();
151  if (isa<FloatType, VectorType, RankedTensorType>(type))
152  return match(attr);
153 
154  return false;
155  }
156 };
157 
158 /// The matcher that matches a given target constant scalar / vector splat /
159 /// tensor splat float value that fulfills a predicate.
161  bool (*predicate)(const APFloat &);
162 
163  bool match(Attribute attr) {
164  APFloat value(APFloat::Bogus());
165  return constant_float_value_binder(&value).match(attr) && predicate(value);
166  }
167 
168  bool match(Operation *op) {
169  APFloat value(APFloat::Bogus());
170  return constant_float_value_binder(&value).match(op) && predicate(value);
171  }
172 };
173 
174 /// The matcher that matches a constant scalar / vector splat / tensor splat
175 /// integer Attribute or Operation and binds the constant integer value.
177  IntegerAttr::ValueType *bind_value;
178 
179  /// Creates a matcher instance that binds the value to bv if match succeeds.
180  constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
181 
182  bool match(Attribute attr) {
184  if (matcher.match(attr))
185  return true;
186 
187  if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
188  return matcher.match(splatAttr.getSplatValue<Attribute>());
189 
190  return false;
191  }
192 
193  bool match(Operation *op) {
194  Attribute attr;
195  if (!constant_op_binder<Attribute>(&attr).match(op))
196  return false;
197 
198  Type type = op->getResult(0).getType();
199  if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(type))
200  return match(attr);
201 
202  return false;
203  }
204 };
205 
206 /// The matcher that matches a given target constant scalar / vector splat /
207 /// tensor splat integer value that fulfills a predicate.
209  bool (*predicate)(const APInt &);
210 
211  bool match(Attribute attr) {
212  APInt value;
213  return constant_int_value_binder(&value).match(attr) && predicate(value);
214  }
215 
216  bool match(Operation *op) {
217  APInt value;
218  return constant_int_value_binder(&value).match(op) && predicate(value);
219  }
220 };
221 
222 /// The matcher that matches a certain kind of op.
223 template <typename OpClass>
224 struct op_matcher {
225  bool match(Operation *op) { return isa<OpClass>(op); }
226 };
227 
/// Trait to check whether T provides a 'match' method with type
/// `MatchTarget` (Value, Operation, or Attribute).
///
/// The alias names the return type of `T::match(MatchTarget)` and is
/// ill-formed when no such call is valid, which makes it usable as the
/// detection expression for `llvm::is_detected`.
template <typename T, typename MatchTarget>
using has_compatible_matcher_t =
    decltype(std::declval<T>().match(std::declval<MatchTarget>()));
233 
234 /// Statically switch to a Value matcher.
235 template <typename MatcherClass>
236 std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
237  MatcherClass, Value>::value,
238  bool>
239 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
240  return matcher.match(op->getOperand(idx));
241 }
242 
243 /// Statically switch to an Operation matcher.
244 template <typename MatcherClass>
245 std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
246  MatcherClass, Operation *>::value,
247  bool>
248 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
249  if (auto *defOp = op->getOperand(idx).getDefiningOp())
250  return matcher.match(defOp);
251  return false;
252 }
253 
254 /// Terminal matcher, always returns true.
256  bool match(Value op) const { return true; }
257 };
258 
259 /// Terminal matcher, always returns true.
263  bool match(Value op) const {
264  *what = op;
265  return true;
266  }
267 };
268 
269 /// Binds to a specific value and matches it.
272  bool match(Value val) const { return val == value; }
274 };
275 
/// Invokes `callback(std::integral_constant<std::size_t, I>{}, std::get<I>(tuple))`
/// once for every index `I` in `Is...`, in pack order (the calls are joined
/// by a comma fold expression). The callback's return values are discarded.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {

  (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
   ...);
}
283 
284 template <typename... Tys, typename CallbackT>
285 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
286  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
287  std::make_index_sequence<sizeof...(Tys)>{});
288 }
289 
290 /// RecursivePatternMatcher that composes.
291 template <typename OpType, typename... OperandMatchers>
293  RecursivePatternMatcher(OperandMatchers... matchers)
294  : operandMatchers(matchers...) {}
295  bool match(Operation *op) {
296  if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
297  return false;
298  bool res = true;
299  enumerate(operandMatchers, [&](size_t index, auto &matcher) {
300  res &= matchOperandOrValueAtIndex(op, index, matcher);
301  });
302  return res;
303  }
304  std::tuple<OperandMatchers...> operandMatchers;
305 };
306 
307 } // namespace detail
308 
309 /// Matches a constant foldable operation.
312 }
313 
314 /// Matches a named attribute operation.
315 inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
316  return detail::AttrOpMatcher(attrName);
317 }
318 
319 /// Matches a named operation.
320 inline detail::NameOpMatcher m_Op(StringRef opName) {
321  return detail::NameOpMatcher(opName);
322 }
323 
324 /// Matches a value from a constant foldable operation and writes the value to
325 /// bind_value.
326 template <typename AttrT>
328  return detail::constant_op_binder<AttrT>(bind_value);
329 }
330 
331 /// Matches a named attribute operation and writes the value to bind_value.
332 template <typename AttrT>
333 inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
334  AttrT *bindValue) {
335  return detail::AttrOpBinder<AttrT>(attrName, bindValue);
336 }
337 
338 /// Matches a constant scalar / vector splat / tensor splat float (both positive
339 /// and negative) zero.
341  return {[](const APFloat &value) { return value.isZero(); }};
342 }
343 
344 /// Matches a constant scalar / vector splat / tensor splat float positive zero.
346  return {[](const APFloat &value) { return value.isPosZero(); }};
347 }
348 
349 /// Matches a constant scalar / vector splat / tensor splat float negative zero.
351  return {[](const APFloat &value) { return value.isNegZero(); }};
352 }
353 
354 /// Matches a constant scalar / vector splat / tensor splat float ones.
356  return {[](const APFloat &value) {
357  return APFloat(value.getSemantics(), 1) == value;
358  }};
359 }
360 
361 /// Matches a constant scalar / vector splat / tensor splat float positive
362 /// infinity.
364  return {[](const APFloat &value) {
365  return !value.isNegative() && value.isInfinity();
366  }};
367 }
368 
369 /// Matches a constant scalar / vector splat / tensor splat float negative
370 /// infinity.
372  return {[](const APFloat &value) {
373  return value.isNegative() && value.isInfinity();
374  }};
375 }
376 
377 /// Matches a constant scalar / vector splat / tensor splat integer zero.
379  return {[](const APInt &value) { return 0 == value; }};
380 }
381 
382 /// Matches a constant scalar / vector splat / tensor splat integer that is any
383 /// non-zero value.
385  return {[](const APInt &value) { return 0 != value; }};
386 }
387 
388 /// Matches a constant scalar / vector splat / tensor splat integer one.
390  return {[](const APInt &value) { return 1 == value; }};
391 }
392 
393 /// Matches the given OpClass.
394 template <typename OpClass>
397 }
398 
399 /// Entry point for matching a pattern over a Value.
400 template <typename Pattern>
401 inline bool matchPattern(Value value, const Pattern &pattern) {
402  assert(value);
403  // TODO: handle other cases
404  if (auto *op = value.getDefiningOp())
405  return const_cast<Pattern &>(pattern).match(op);
406  return false;
407 }
408 
409 /// Entry point for matching a pattern over an Operation.
410 template <typename Pattern>
411 inline bool matchPattern(Operation *op, const Pattern &pattern) {
412  assert(op);
413  return const_cast<Pattern &>(pattern).match(op);
414 }
415 
416 /// Entry point for matching a pattern over an Attribute. Returns `false`
417 /// when `attr` is null.
418 template <typename Pattern>
419 inline bool matchPattern(Attribute attr, const Pattern &pattern) {
420  static_assert(llvm::is_detected<detail::has_compatible_matcher_t, Pattern,
421  Attribute>::value,
422  "Pattern does not support matching Attributes");
423  if (!attr)
424  return false;
425  return const_cast<Pattern &>(pattern).match(attr);
426 }
427 
428 /// Matches a constant holding a scalar/vector/tensor float (splat) and
429 /// writes the float value to bind_value.
430 inline detail::constant_float_value_binder
431 m_ConstantFloat(FloatAttr::ValueType *bind_value) {
432  return detail::constant_float_value_binder(bind_value);
433 }
434 
435 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
436 /// writes the integer value to bind_value.
437 inline detail::constant_int_value_binder
438 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
439  return detail::constant_int_value_binder(bind_value);
440 }
441 
442 template <typename OpType, typename... Matchers>
443 auto m_Op(Matchers... matchers) {
444  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
445 }
446 
447 namespace matchers {
448 inline auto m_Any() { return detail::AnyValueMatcher(); }
449 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
450 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
451 } // namespace matchers
452 
453 } // namespace mlir
454 
455 #endif // MLIR_IR_MATCHERS_H
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
Definition: Matchers.h:277
std::enable_if_t< llvm::is_detected< detail::has_compatible_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
Definition: Matchers.h:239
decltype(std::declval< T >().match(std::declval< MatchTarget >())) has_compatible_matcher_t
Trait to check whether T provides a 'match' method with type MatchTarget (Value, Operation,...
Definition: Matchers.h:232
auto m_Val(Value v)
Definition: Matchers.h:450
auto m_Any()
Definition: Matchers.h:448
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
detail::AttrOpMatcher m_Attr(StringRef attrName)
Matches a named attribute operation.
Definition: Matchers.h:315
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
Definition: Matchers.h:320
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:345
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:340
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:384
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:371
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:350
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:363
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition: Matchers.h:431
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:355
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Terminal matcher, always returns true.
Definition: Matchers.h:260
Terminal matcher, always returns true.
Definition: Matchers.h:255
bool match(Value op) const
Definition: Matchers.h:256
The matcher that matches operations that have the specified attribute name, and binds the attribute v...
Definition: Matchers.h:106
AttrOpBinder(StringRef attrName, AttrT *bindValue)
Creates a matcher instance that binds the attribute value to bind_value if match succeeds.
Definition: Matchers.h:109
bool match(Operation *op)
Definition: Matchers.h:114
AttrOpBinder(StringRef attrName)
Creates a matcher instance that doesn't bind if match succeeds.
Definition: Matchers.h:112
The matcher that matches operations that have the specified attribute name.
Definition: Matchers.h:65
bool match(Operation *op)
Definition: Matchers.h:67
AttrOpMatcher(StringRef attrName)
Definition: Matchers.h:66
The matcher that matches operations that have the specified op name.
Definition: Matchers.h:57
NameOpMatcher(StringRef name)
Definition: Matchers.h:58
bool match(Operation *op)
Definition: Matchers.h:59
Binds to a specific value and matches it.
Definition: Matchers.h:270
bool match(Value val) const
Definition: Matchers.h:272
RecursivePatternMatcher that composes.
Definition: Matchers.h:292
RecursivePatternMatcher(OperandMatchers... matchers)
Definition: Matchers.h:293
std::tuple< OperandMatchers... > operandMatchers
Definition: Matchers.h:304
The matcher that matches a certain kind of Attribute and binds the value inside the Attribute.
Definition: Matchers.h:36
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:40
bool match(Attribute attr)
Definition: Matchers.h:42
The matcher that matches a given target constant scalar / vector splat / tensor splat float value tha...
Definition: Matchers.h:160
The matcher that matches a constant scalar / vector splat / tensor splat float Attribute or Operation...
Definition: Matchers.h:128
constant_float_value_binder(FloatAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:132
FloatAttr::ValueType * bind_value
Definition: Matchers.h:129
The matcher that matches a given target constant scalar / vector splat / tensor splat integer value t...
Definition: Matchers.h:208
The matcher that matches a constant scalar / vector splat / tensor splat integer Attribute or Operati...
Definition: Matchers.h:176
constant_int_value_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:180
IntegerAttr::ValueType * bind_value
Definition: Matchers.h:177
The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute ...
Definition: Matchers.h:75
constant_op_binder()
Creates a matcher instance that doesn't bind if match succeeds.
Definition: Matchers.h:82
constant_op_binder(AttrT *bind_value)
Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds.
Definition: Matchers.h:80
bool match(Operation *op)
Definition: Matchers.h:84
The matcher that matches operations that have the ConstantLike trait.
Definition: Matchers.h:52
bool match(Operation *op)
Definition: Matchers.h:53
The matcher that matches a certain kind of op.
Definition: Matchers.h:224
bool match(Operation *op)
Definition: Matchers.h:225