MLIR  16.0.0git
Matchers.h
Go to the documentation of this file.
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 
21 namespace mlir {
22 
23 namespace detail {
24 
25 /// The matcher that matches a certain kind of Attribute and binds the value
26 /// inside the Attribute.
27 template <
28  typename AttrClass,
29  // Require AttrClass to be a derived class from Attribute and get its
30  // value type
31  typename ValueType = typename std::enable_if_t<
32  std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
33  // Require the ValueType is not void
36  ValueType *bind_value;
37 
38  /// Creates a matcher instance that binds the value to bv if match succeeds.
39  attr_value_binder(ValueType *bv) : bind_value(bv) {}
40 
41  bool match(const Attribute &attr) {
42  if (auto intAttr = attr.dyn_cast<AttrClass>()) {
43  *bind_value = intAttr.getValue();
44  return true;
45  }
46  return false;
47  }
48 };
49 
50 /// Check to see if the specified operation is ConstantLike. This includes some
51 /// quick filters to avoid a semi-expensive test in the common case.
52 static bool isConstantLike(Operation *op) {
53  return op->getNumOperands() == 0 && op->getNumResults() == 1 &&
55 }
56 
57 /// The matcher that matches operations that have the `ConstantLike` trait.
59  bool match(Operation *op) { return isConstantLike(op); }
60 };
61 
62 /// The matcher that matches operations that have the `ConstantLike` trait, and
63 /// binds the folded attribute value.
64 template <typename AttrT>
66  AttrT *bind_value;
67 
68  /// Creates a matcher instance that binds the constant attribute value to
69  /// bind_value if match succeeds.
71  /// Creates a matcher instance that doesn't bind if match succeeds.
73 
74  bool match(Operation *op) {
75  if (!isConstantLike(op))
76  return false;
77 
78  // Fold the constant to an attribute.
80  LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
81  (void)result;
82  assert(succeeded(result) && "expected ConstantLike op to be foldable");
83 
84  if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
85  if (bind_value)
86  *bind_value = attr;
87  return true;
88  }
89  return false;
90  }
91 };
92 
93 /// The matcher that matches a constant scalar / vector splat / tensor splat
94 /// float operation and binds the constant float value.
96  FloatAttr::ValueType *bind_value;
97 
98  /// Creates a matcher instance that binds the value to bv if match succeeds.
99  constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
100 
101  bool match(Operation *op) {
102  Attribute attr;
103  if (!constant_op_binder<Attribute>(&attr).match(op))
104  return false;
105  auto type = op->getResult(0).getType();
106 
107  if (type.isa<FloatType>())
109  if (type.isa<VectorType, RankedTensorType>()) {
110  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
112  .match(splatAttr.getSplatValue<Attribute>());
113  }
114  }
115  return false;
116  }
117 };
118 
119 /// The matcher that matches a given target constant scalar / vector splat /
120 /// tensor splat float value that fulfills a predicate.
122  bool (*predicate)(const APFloat &);
123 
124  bool match(Operation *op) {
125  APFloat value(APFloat::Bogus());
127  }
128 };
129 
130 /// The matcher that matches a constant scalar / vector splat / tensor splat
131 /// integer operation and binds the constant integer value.
133  IntegerAttr::ValueType *bind_value;
134 
135  /// Creates a matcher instance that binds the value to bv if match succeeds.
136  constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
137 
138  bool match(Operation *op) {
139  Attribute attr;
140  if (!constant_op_binder<Attribute>(&attr).match(op))
141  return false;
142  auto type = op->getResult(0).getType();
143 
144  if (type.isa<IntegerType, IndexType>())
146  if (type.isa<VectorType, RankedTensorType>()) {
147  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
149  .match(splatAttr.getSplatValue<Attribute>());
150  }
151  }
152  return false;
153  }
154 };
155 
156 /// The matcher that matches a given target constant scalar / vector splat /
157 /// tensor splat integer value that fulfills a predicate.
159  bool (*predicate)(const APInt &);
160 
161  bool match(Operation *op) {
162  APInt value;
164  }
165 };
166 
167 /// The matcher that matches a certain kind of op.
/// `match` succeeds exactly when `isa<OpClass>(op)` holds; nothing is bound
/// and the operands of `op` are not inspected.
168 template <typename OpClass>
169 struct op_matcher {
170  bool match(Operation *op) { return isa<OpClass>(op); }
171 };
172 
173 /// Trait to check whether T provides a 'match' method that can be called
174 /// with an argument of type `OperationOrValue`.
175 template <typename T, typename OperationOrValue>
177  decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
178 
179 /// Statically switch to a Value matcher.
/// This overload takes part in overload resolution only when
/// `matcher.match(Value)` is a valid expression (detected through
/// `llvm::is_detected` with `has_operation_or_value_matcher_t`). It hands
/// operand `idx` of `op` to the matcher and returns the matcher's result.
/// NOTE(review): a matcher whose `match` accepts both `Value` and
/// `Operation *` would enable this overload and the `Operation *` overload at
/// the same time; presumably matchers are expected to provide only one of the
/// two -- confirm.
180 template <typename MatcherClass>
181 std::enable_if_t<llvm::is_detected<detail::has_operation_or_value_matcher_t,
182  MatcherClass, Value>::value,
183  bool>
184 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
185  return matcher.match(op->getOperand(idx));
186 }
187 
188 /// Statically switch to an Operation matcher.
/// This overload takes part in overload resolution only when
/// `matcher.match(Operation *)` is a valid expression. The matcher is applied
/// to the operation that defines operand `idx` of `op`; when that operand has
/// no defining operation the result is false.
189 template <typename MatcherClass>
190 std::enable_if_t<llvm::is_detected<detail::has_operation_or_value_matcher_t,
191  MatcherClass, Operation *>::value,
192  bool>
193 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
194  if (auto *defOp = op->getOperand(idx).getDefiningOp())
195  return matcher.match(defOp);
// No defining operation is available, so there is nothing an operation
// matcher could be applied to.
196  return false;
197 }
198 
199 /// Terminal matcher, always returns true.
201  bool match(Value op) const { return true; }
202 };
203 
204 /// Terminal matcher, always returns true and writes the matched value to `*what`.
208  bool match(Value op) const {
209  *what = op;
210  return true;
211  }
212 };
213 
214 /// Binds to a specific value and matches it.
217  bool match(Value val) const { return val == value; }
219 };
220 
/// Implementation helper for `enumerate`: for every index `I` in `Is`, calls
/// `callback(std::integral_constant<std::size_t, I>{}, std::get<I>(tuple))`.
/// The third parameter is unnamed and only serves to deduce `Is...`.
221 template <typename TupleT, class CallbackT, std::size_t... Is>
222 constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
223  std::index_sequence<Is...>) {
224  // Comma fold: one call per index, evaluated in increasing index order.
225  (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
226  ...);
227 }
228 
/// Calls `callback` once per element of `tuple`, passing the element's
/// position as a `std::integral_constant<std::size_t, I>` followed by a
/// reference to the element itself. The index sequence covering all
/// `sizeof...(Tys)` elements is generated here and the work is delegated to
/// `detail::enumerateImpl`.
229 template <typename... Tys, typename CallbackT>
230 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
231  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
232  std::make_index_sequence<sizeof...(Tys)>{});
233 }
234 
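235 /// RecursivePatternMatcher that composes: it matches an operation of type `OpType` whose operands each satisfy the corresponding operand matcher.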
236 template <typename OpType, typename... OperandMatchers>
238  RecursivePatternMatcher(OperandMatchers... matchers)
239  : operandMatchers(matchers...) {}
240  bool match(Operation *op) {
241  if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
242  return false;
243  bool res = true;
244  enumerate(operandMatchers, [&](size_t index, auto &matcher) {
245  res &= matchOperandOrValueAtIndex(op, index, matcher);
246  });
247  return res;
248  }
249  std::tuple<OperandMatchers...> operandMatchers;
250 };
251 
252 } // namespace detail
253 
254 /// Matches a constant foldable operation.
257 }
258 
259 /// Matches a value from a constant foldable operation and writes the value to
260 /// bind_value.
261 template <typename AttrT>
263  return detail::constant_op_binder<AttrT>(bind_value);
264 }
265 
266 /// Matches a constant scalar / vector splat / tensor splat float (both positive
267 /// and negative) zero.
269  return {[](const APFloat &value) { return value.isZero(); }};
270 }
271 
272 /// Matches a constant scalar / vector splat / tensor splat float positive zero.
274  return {[](const APFloat &value) { return value.isPosZero(); }};
275 }
276 
277 /// Matches a constant scalar / vector splat / tensor splat float negative zero.
279  return {[](const APFloat &value) { return value.isNegZero(); }};
280 }
281 
282 /// Matches a constant scalar / vector splat / tensor splat float one.
284  return {[](const APFloat &value) {
285  return APFloat(value.getSemantics(), 1) == value;
286  }};
287 }
288 
289 /// Matches a constant scalar / vector splat / tensor splat float positive
290 /// infinity.
292  return {[](const APFloat &value) {
293  return !value.isNegative() && value.isInfinity();
294  }};
295 }
296 
297 /// Matches a constant scalar / vector splat / tensor splat float negative
298 /// infinity.
300  return {[](const APFloat &value) {
301  return value.isNegative() && value.isInfinity();
302  }};
303 }
304 
305 /// Matches a constant scalar / vector splat / tensor splat integer zero.
307  return {[](const APInt &value) { return 0 == value; }};
308 }
309 
310 /// Matches a constant scalar / vector splat / tensor splat integer that is any
311 /// non-zero value.
313  return {[](const APInt &value) { return 0 != value; }};
314 }
315 
316 /// Matches a constant scalar / vector splat / tensor splat integer one.
318  return {[](const APInt &value) { return 1 == value; }};
319 }
320 
321 /// Matches the given OpClass.
322 template <typename OpClass>
325 }
326 
327 /// Entry point for matching a pattern over a Value.
/// Returns false when `value` has no defining operation; otherwise returns
/// the result of `pattern.match` applied to that defining operation.
328 template <typename Pattern>
329 inline bool matchPattern(Value value, const Pattern &pattern) {
330  // TODO: handle other cases
// `pattern` is received by const reference, but `match` is invoked through
// a non-const reference so that matchers with a non-const `match` member
// (the form used by the matchers in this file) can be passed in.
331  if (auto *op = value.getDefiningOp())
332  return const_cast<Pattern &>(pattern).match(op);
333  return false;
334 }
335 
336 /// Entry point for matching a pattern over an Operation.
/// Returns the result of `pattern.match(op)`. `op` is forwarded to the
/// matcher as-is; no null check is performed here.
337 template <typename Pattern>
338 inline bool matchPattern(Operation *op, const Pattern &pattern) {
// The const_cast lets a matcher received by const reference run its
// non-const `match` member.
339  return const_cast<Pattern &>(pattern).match(op);
340 }
341 
342 /// Matches a constant holding a scalar/vector/tensor float (splat) and
343 /// writes the float value to bind_value.
/// The pointer is stored in the returned detail::constant_float_op_binder.
/// NOTE(review): `bind_value` presumably has to be non-null and must stay
/// valid for as long as the returned matcher is used -- confirm against
/// detail::constant_float_op_binder::match.
344 inline detail::constant_float_op_binder
345 m_ConstantFloat(FloatAttr::ValueType *bind_value) {
346  return detail::constant_float_op_binder(bind_value);
347 }
348 
349 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
350 /// writes the integer value to bind_value.
/// The pointer is stored in the returned detail::constant_int_op_binder.
/// NOTE(review): `bind_value` presumably has to be non-null and must stay
/// valid for as long as the returned matcher is used -- confirm against
/// detail::constant_int_op_binder::match.
351 inline detail::constant_int_op_binder
352 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
353  return detail::constant_int_op_binder(bind_value);
354 }
355 
/// Builds a detail::RecursivePatternMatcher for `OpType` with one sub-matcher
/// per operand. The resulting matcher only accepts an `OpType` operation whose
/// operand count equals the number of `matchers`, and it requires operand i to
/// satisfy the i-th matcher. The matchers are taken and stored by value.
356 template <typename OpType, typename... Matchers>
357 auto m_Op(Matchers... matchers) {
358  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
359 }
360 
/// Factory functions for the terminal matchers defined in `detail`. Their
/// `match` members take a `Value`, so they fit the Value overload of
/// detail::matchOperandOrValueAtIndex.
361 namespace matchers {
/// Returns a matcher that accepts any value and binds nothing.
362 inline auto m_Any() { return detail::AnyValueMatcher(); }
/// Returns a matcher that accepts any value and writes the matched value
/// through the pointer it holds (presumably `val`; the constructor of
/// detail::AnyCapturedValueMatcher is not visible here -- confirm).
363 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
/// Returns a matcher that accepts only a value comparing equal to the one it
/// holds (presumably `v` -- confirm against detail::PatternMatcherValue).
364 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
365 } // namespace matchers
366 
367 } // namespace mlir
368 
369 #endif // MLIR_IR_MATCHERS_H
static constexpr const bool value
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:490
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:528
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
unsigned getNumOperands()
Definition: Operation.h:263
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:71
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the elements have the same value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
decltype(std::declval< T >().match(std::declval< OperationOrValue >())) has_operation_or_value_matcher_t
Trait to check whether T provides a 'match' method with type OperationOrValue.
Definition: Matchers.h:177
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
Definition: Matchers.h:222
std::enable_if_t< llvm::is_detected< detail::has_operation_or_value_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
Definition: Matchers.h:184
static bool isConstantLike(Operation *op)
Check to see if the specified operation is ConstantLike.
Definition: Matchers.h:52
auto m_Val(Value v)
Definition: Matchers.h:364
auto m_Any()
Definition: Matchers.h:362
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:329
detail::constant_float_op_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_value.
Definition: Matchers.h:345
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:273
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:306
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:268
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:317
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:352
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:312
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:299
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:278
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition: Matchers.h:323
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:255
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:291
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:283
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Terminal matcher, always returns true.
Definition: Matchers.h:205
Terminal matcher, always returns true.
Definition: Matchers.h:200
bool match(Value op) const
Definition: Matchers.h:201
Binds to a specific value and matches it.
Definition: Matchers.h:215
bool match(Value val) const
Definition: Matchers.h:217
RecursivePatternMatcher that composes.
Definition: Matchers.h:237
RecursivePatternMatcher(OperandMatchers... matchers)
Definition: Matchers.h:238
std::tuple< OperandMatchers... > operandMatchers
Definition: Matchers.h:249
The matcher that matches a certain kind of Attribute and binds the value inside the Attribute.
Definition: Matchers.h:35
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:39
bool match(const Attribute &attr)
Definition: Matchers.h:41
The matcher that matches a constant scalar / vector splat / tensor splat float operation and binds the constant float value.
Definition: Matchers.h:95
constant_float_op_binder(FloatAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:99
FloatAttr::ValueType * bind_value
Definition: Matchers.h:96
The matcher that matches a given target constant scalar / vector splat / tensor splat float value that fulfills a predicate.
Definition: Matchers.h:121
The matcher that matches a constant scalar / vector splat / tensor splat integer operation and binds the constant integer value.
Definition: Matchers.h:132
constant_int_op_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:136
IntegerAttr::ValueType * bind_value
Definition: Matchers.h:133
The matcher that matches a given target constant scalar / vector splat / tensor splat integer value that fulfills a predicate.
Definition: Matchers.h:158
The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute value.
Definition: Matchers.h:65
constant_op_binder()
Creates a matcher instance that doesn't bind if match succeeds.
Definition: Matchers.h:72
constant_op_binder(AttrT *bind_value)
Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds.
Definition: Matchers.h:70
bool match(Operation *op)
Definition: Matchers.h:74
The matcher that matches operations that have the ConstantLike trait.
Definition: Matchers.h:58
bool match(Operation *op)
Definition: Matchers.h:59
The matcher that matches a certain kind of op.
Definition: Matchers.h:169
bool match(Operation *op)
Definition: Matchers.h:170