MLIR  17.0.0git
Matchers.h
Go to the documentation of this file.
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 
21 namespace mlir {
22 
23 namespace detail {
24 
25 /// The matcher that matches a certain kind of Attribute and binds the value
26 /// inside the Attribute.
27 template <
28  typename AttrClass,
29  // Require AttrClass to be a derived class from Attribute and get its
30  // value type
31  typename ValueType = typename std::enable_if_t<
32  std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
33  // Require the ValueType is not void
34  typename = std::enable_if_t<!std::is_void<ValueType>::value>>
36  ValueType *bind_value;
37 
38  /// Creates a matcher instance that binds the value to bv if match succeeds.
39  attr_value_binder(ValueType *bv) : bind_value(bv) {}
40 
41  bool match(const Attribute &attr) {
42  if (auto intAttr = attr.dyn_cast<AttrClass>()) {
43  *bind_value = intAttr.getValue();
44  return true;
45  }
46  return false;
47  }
48 };
49 
50 /// The matcher that matches operations that have the `ConstantLike` trait.
52  bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
53 };
54 
55 /// The matcher that matches operations that have the `ConstantLike` trait, and
56 /// binds the folded attribute value.
57 template <typename AttrT>
59  AttrT *bind_value;
60 
61  /// Creates a matcher instance that binds the constant attribute value to
62  /// bind_value if match succeeds.
64  /// Creates a matcher instance that doesn't bind if match succeeds.
66 
67  bool match(Operation *op) {
68  if (!op->hasTrait<OpTrait::ConstantLike>())
69  return false;
70 
71  // Fold the constant to an attribute.
73  LogicalResult result = op->fold(/*operands=*/std::nullopt, foldedOp);
74  (void)result;
75  assert(succeeded(result) && "expected ConstantLike op to be foldable");
76 
77  if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
78  if (bind_value)
79  *bind_value = attr;
80  return true;
81  }
82  return false;
83  }
84 };
85 
86 /// The matcher that matches a constant scalar / vector splat / tensor splat
87 /// float operation and binds the constant float value.
89  FloatAttr::ValueType *bind_value;
90 
91  /// Creates a matcher instance that binds the value to bv if match succeeds.
92  constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
93 
94  bool match(Operation *op) {
95  Attribute attr;
96  if (!constant_op_binder<Attribute>(&attr).match(op))
97  return false;
98  auto type = op->getResult(0).getType();
99 
100  if (type.isa<FloatType>())
102  if (type.isa<VectorType, RankedTensorType>()) {
103  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
105  .match(splatAttr.getSplatValue<Attribute>());
106  }
107  }
108  return false;
109  }
110 };
111 
112 /// The matcher that matches a given target constant scalar / vector splat /
113 /// tensor splat float value that fulfills a predicate.
115  bool (*predicate)(const APFloat &);
116 
117  bool match(Operation *op) {
118  APFloat value(APFloat::Bogus());
119  return constant_float_op_binder(&value).match(op) && predicate(value);
120  }
121 };
122 
123 /// The matcher that matches a constant scalar / vector splat / tensor splat
124 /// integer operation and binds the constant integer value.
126  IntegerAttr::ValueType *bind_value;
127 
128  /// Creates a matcher instance that binds the value to bv if match succeeds.
129  constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
130 
131  bool match(Operation *op) {
132  Attribute attr;
133  if (!constant_op_binder<Attribute>(&attr).match(op))
134  return false;
135  auto type = op->getResult(0).getType();
136 
137  if (type.isa<IntegerType, IndexType>())
139  if (type.isa<VectorType, RankedTensorType>()) {
140  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
142  .match(splatAttr.getSplatValue<Attribute>());
143  }
144  }
145  return false;
146  }
147 };
148 
149 /// The matcher that matches a given target constant scalar / vector splat /
150 /// tensor splat integer value that fulfills a predicate.
152  bool (*predicate)(const APInt &);
153 
154  bool match(Operation *op) {
155  APInt value;
156  return constant_int_op_binder(&value).match(op) && predicate(value);
157  }
158 };
159 
160 /// The matcher that matches a certain kind of op.
161 template <typename OpClass>
162 struct op_matcher {
163  bool match(Operation *op) { return isa<OpClass>(op); }
164 };
165 
166 /// Trait to check whether T provides a 'match' method that accepts an
167 /// argument of type `OperationOrValue`.
168 template <typename T, typename OperationOrValue>
170  decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
171 
172 /// Statically switch to a Value matcher.
173 template <typename MatcherClass>
174 std::enable_if_t<llvm::is_detected<detail::has_operation_or_value_matcher_t,
175  MatcherClass, Value>::value,
176  bool>
177 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
178  return matcher.match(op->getOperand(idx));
179 }
180 
181 /// Statically switch to an Operation matcher.
182 template <typename MatcherClass>
183 std::enable_if_t<llvm::is_detected<detail::has_operation_or_value_matcher_t,
184  MatcherClass, Operation *>::value,
185  bool>
186 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
187  if (auto *defOp = op->getOperand(idx).getDefiningOp())
188  return matcher.match(defOp);
189  return false;
190 }
191 
192 /// Terminal matcher, always returns true.
194  bool match(Value op) const { return true; }
195 };
196 
197 /// Terminal matcher that always returns true and captures the matched value.
201  bool match(Value op) const {
202  *what = op;
203  return true;
204  }
205 };
206 
207 /// Binds to a specific value and matches it.
210  bool match(Value val) const { return val == value; }
212 };
213 
/// Invokes `callback(index, element)` once for every index in `Is...`, in
/// left-to-right order.
///
/// `index` is passed as a `std::integral_constant<std::size_t, I>`, so the
/// callback may take it as a plain `std::size_t` or use it as a compile-time
/// constant. `element` is `std::get<I>(tuple)`. Any value returned by the
/// callback is discarded.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // This is a comma fold over the index pack. Each call's result is cast to
  // void so that only the built-in comma operator is used. Without the cast,
  // a callback returning a type with an overloaded `operator,` would have that
  // user operator invoked between calls.
  ((void)callback(std::integral_constant<std::size_t, Is>{},
                  std::get<Is>(tuple)),
   ...);
}
221 
222 template <typename... Tys, typename CallbackT>
223 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
224  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
225  std::make_index_sequence<sizeof...(Tys)>{});
226 }
227 
228 /// RecursivePatternMatcher that composes: matches an `OpType` operation whose operands are matched, in order, by the operand matchers.
229 template <typename OpType, typename... OperandMatchers>
231  RecursivePatternMatcher(OperandMatchers... matchers)
232  : operandMatchers(matchers...) {}
233  bool match(Operation *op) {
234  if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
235  return false;
236  bool res = true;
237  enumerate(operandMatchers, [&](size_t index, auto &matcher) {
238  res &= matchOperandOrValueAtIndex(op, index, matcher);
239  });
240  return res;
241  }
242  std::tuple<OperandMatchers...> operandMatchers;
243 };
244 
245 } // namespace detail
246 
247 /// Matches a constant foldable operation.
250 }
251 
252 /// Matches a value from a constant foldable operation and writes the value to
253 /// bind_value.
254 template <typename AttrT>
256  return detail::constant_op_binder<AttrT>(bind_value);
257 }
258 
259 /// Matches a constant scalar / vector splat / tensor splat float (both positive
260 /// and negative) zero.
262  return {[](const APFloat &value) { return value.isZero(); }};
263 }
264 
265 /// Matches a constant scalar / vector splat / tensor splat float positive zero.
267  return {[](const APFloat &value) { return value.isPosZero(); }};
268 }
269 
270 /// Matches a constant scalar / vector splat / tensor splat float negative zero.
272  return {[](const APFloat &value) { return value.isNegZero(); }};
273 }
274 
275 /// Matches a constant scalar / vector splat / tensor splat float one (1.0).
277  return {[](const APFloat &value) {
278  return APFloat(value.getSemantics(), 1) == value;
279  }};
280 }
281 
282 /// Matches a constant scalar / vector splat / tensor splat float positive
283 /// infinity.
285  return {[](const APFloat &value) {
286  return !value.isNegative() && value.isInfinity();
287  }};
288 }
289 
290 /// Matches a constant scalar / vector splat / tensor splat float negative
291 /// infinity.
293  return {[](const APFloat &value) {
294  return value.isNegative() && value.isInfinity();
295  }};
296 }
297 
298 /// Matches a constant scalar / vector splat / tensor splat integer zero.
300  return {[](const APInt &value) { return 0 == value; }};
301 }
302 
303 /// Matches a constant scalar / vector splat / tensor splat integer that is any
304 /// non-zero value.
306  return {[](const APInt &value) { return 0 != value; }};
307 }
308 
309 /// Matches a constant scalar / vector splat / tensor splat integer one.
311  return {[](const APInt &value) { return 1 == value; }};
312 }
313 
314 /// Matches the given OpClass.
315 template <typename OpClass>
318 }
319 
320 /// Entry point for matching a pattern over a Value.
321 template <typename Pattern>
322 inline bool matchPattern(Value value, const Pattern &pattern) {
323  // TODO: handle other cases
324  if (auto *op = value.getDefiningOp())
325  return const_cast<Pattern &>(pattern).match(op);
326  return false;
327 }
328 
329 /// Entry point for matching a pattern over an Operation.
330 template <typename Pattern>
331 inline bool matchPattern(Operation *op, const Pattern &pattern) {
332  return const_cast<Pattern &>(pattern).match(op);
333 }
334 
335 /// Matches a constant holding a scalar/vector/tensor float (splat) and
336 /// writes the float value to bind_value.
337 inline detail::constant_float_op_binder
338 m_ConstantFloat(FloatAttr::ValueType *bind_value) {
339  return detail::constant_float_op_binder(bind_value);
340 }
341 
342 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
343 /// writes the integer value to bind_value.
344 inline detail::constant_int_op_binder
345 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
346  return detail::constant_int_op_binder(bind_value);
347 }
348 
349 template <typename OpType, typename... Matchers>
350 auto m_Op(Matchers... matchers) {
351  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
352 }
353 
354 namespace matchers {
355 inline auto m_Any() { return detail::AnyValueMatcher(); }
356 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
357 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
358 } // namespace matchers
359 
360 } // namespace mlir
361 
362 #endif // MLIR_IR_MATCHERS_H
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:166
This class provides the API for a subset of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:499
Value getOperand(unsigned idx)
Definition: Operation.h:329
bool hasTrait()
Returns true if the operation was registered with a particular trait.
Definition: Operation.h:592
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
unsigned getNumOperands()
Definition: Operation.h:325
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
decltype(std::declval< T >().match(std::declval< OperationOrValue >())) has_operation_or_value_matcher_t
Trait to check whether T provides a 'match' method with type OperationOrValue.
Definition: Matchers.h:170
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
Definition: Matchers.h:215
std::enable_if_t< llvm::is_detected< detail::has_operation_or_value_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
Definition: Matchers.h:177
auto m_Val(Value v)
Definition: Matchers.h:357
auto m_Any()
Definition: Matchers.h:355
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:322
detail::constant_float_op_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_value.
Definition: Matchers.h:338
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:266
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:299
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:261
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:310
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:345
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:305
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:292
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:271
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition: Matchers.h:316
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:248
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:284
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float one (1.0).
Definition: Matchers.h:276
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Terminal matcher that always returns true and captures the matched value.
Definition: Matchers.h:198
Terminal matcher, always returns true.
Definition: Matchers.h:193
bool match(Value op) const
Definition: Matchers.h:194
Binds to a specific value and matches it.
Definition: Matchers.h:208
bool match(Value val) const
Definition: Matchers.h:210
RecursivePatternMatcher that composes: matches an OpType operation whose operands are matched, in order, by the operand matchers.
Definition: Matchers.h:230
RecursivePatternMatcher(OperandMatchers... matchers)
Definition: Matchers.h:231
std::tuple< OperandMatchers... > operandMatchers
Definition: Matchers.h:242
The matcher that matches a certain kind of Attribute and binds the value inside the Attribute.
Definition: Matchers.h:35
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:39
bool match(const Attribute &attr)
Definition: Matchers.h:41
The matcher that matches a constant scalar / vector splat / tensor splat float operation and binds the constant float value.
Definition: Matchers.h:88
constant_float_op_binder(FloatAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:92
FloatAttr::ValueType * bind_value
Definition: Matchers.h:89
The matcher that matches a given target constant scalar / vector splat / tensor splat float value that fulfills a predicate.
Definition: Matchers.h:114
The matcher that matches a constant scalar / vector splat / tensor splat integer operation and binds the constant integer value.
Definition: Matchers.h:125
constant_int_op_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:129
IntegerAttr::ValueType * bind_value
Definition: Matchers.h:126
The matcher that matches a given target constant scalar / vector splat / tensor splat integer value that fulfills a predicate.
Definition: Matchers.h:151
The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute value.
Definition: Matchers.h:58
constant_op_binder()
Creates a matcher instance that doesn't bind if match succeeds.
Definition: Matchers.h:65
constant_op_binder(AttrT *bind_value)
Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds.
Definition: Matchers.h:63
bool match(Operation *op)
Definition: Matchers.h:67
The matcher that matches operations that have the ConstantLike trait.
Definition: Matchers.h:51
bool match(Operation *op)
Definition: Matchers.h:52
The matcher that matches a certain kind of op.
Definition: Matchers.h:162
bool match(Operation *op)
Definition: Matchers.h:163