MLIR  15.0.0git
Matchers.h
Go to the documentation of this file.
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 
21 namespace mlir {
22 
23 namespace detail {
24 
25 /// The matcher that matches a certain kind of Attribute and binds the value
26 /// inside the Attribute.
27 template <
28  typename AttrClass,
29  // Require AttrClass to be a derived class from Attribute and get its
30  // value type
31  typename ValueType =
33  AttrClass>::type::ValueType,
34  // Require the ValueType is not void
35  typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
37  ValueType *bind_value;
38 
39  /// Creates a matcher instance that binds the value to bv if match succeeds.
40  attr_value_binder(ValueType *bv) : bind_value(bv) {}
41 
42  bool match(const Attribute &attr) {
43  if (auto intAttr = attr.dyn_cast<AttrClass>()) {
44  *bind_value = intAttr.getValue();
45  return true;
46  }
47  return false;
48  }
49 };
50 
51 /// Check to see if the specified operation is ConstantLike. This includes some
52 /// quick filters to avoid a semi-expensive test in the common case.
53 static bool isConstantLike(Operation *op) {
54  return op->getNumOperands() == 0 && op->getNumResults() == 1 &&
56 }
57 
58 /// The matcher that matches operations that have the `ConstantLike` trait.
60  bool match(Operation *op) { return isConstantLike(op); }
61 };
62 
63 /// The matcher that matches operations that have the `ConstantLike` trait, and
64 /// binds the folded attribute value.
65 template <typename AttrT>
67  AttrT *bind_value;
68 
69  /// Creates a matcher instance that binds the constant attribute value to
70  /// bind_value if match succeeds.
71  constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
72  /// Creates a matcher instance that doesn't bind if match succeeds.
73  constant_op_binder() : bind_value(nullptr) {}
74 
75  bool match(Operation *op) {
76  if (!isConstantLike(op))
77  return false;
78 
79  // Fold the constant to an attribute.
81  LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
82  (void)result;
83  assert(succeeded(result) && "expected ConstantLike op to be foldable");
84 
85  if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
86  if (bind_value)
87  *bind_value = attr;
88  return true;
89  }
90  return false;
91  }
92 };
93 
94 /// The matcher that matches a constant scalar / vector splat / tensor splat
95 /// float operation and binds the constant float value.
97  FloatAttr::ValueType *bind_value;
98 
99  /// Creates a matcher instance that binds the value to bv if match succeeds.
100  constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
101 
102  bool match(Operation *op) {
103  Attribute attr;
104  if (!constant_op_binder<Attribute>(&attr).match(op))
105  return false;
106  auto type = op->getResult(0).getType();
107 
108  if (type.isa<FloatType>())
110  if (type.isa<VectorType, RankedTensorType>()) {
111  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
113  .match(splatAttr.getSplatValue<Attribute>());
114  }
115  }
116  return false;
117  }
118 };
119 
120 /// The matcher that matches a given target constant scalar / vector splat /
121 /// tensor splat float value that fulfills a predicate.
123  bool (*predicate)(const APFloat &);
124 
125  bool match(Operation *op) {
126  APFloat value(APFloat::Bogus());
127  return constant_float_op_binder(&value).match(op) && predicate(value);
128  }
129 };
130 
131 /// The matcher that matches a constant scalar / vector splat / tensor splat
132 /// integer operation and binds the constant integer value.
134  IntegerAttr::ValueType *bind_value;
135 
136  /// Creates a matcher instance that binds the value to bv if match succeeds.
137  constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
138 
139  bool match(Operation *op) {
140  Attribute attr;
141  if (!constant_op_binder<Attribute>(&attr).match(op))
142  return false;
143  auto type = op->getResult(0).getType();
144 
145  if (type.isa<IntegerType, IndexType>())
147  if (type.isa<VectorType, RankedTensorType>()) {
148  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
150  .match(splatAttr.getSplatValue<Attribute>());
151  }
152  }
153  return false;
154  }
155 };
156 
157 /// The matcher that matches a given target constant scalar / vector splat /
158 /// tensor splat integer value that fulfills a predicate.
160  bool (*predicate)(const APInt &);
161 
162  bool match(Operation *op) {
163  APInt value;
164  return constant_int_op_binder(&value).match(op) && predicate(value);
165  }
166 };
167 
168 /// The matcher that matches a certain kind of op.
169 template <typename OpClass>
170 struct op_matcher {
171  bool match(Operation *op) { return isa<OpClass>(op); }
172 };
173 
174 /// Trait to check whether T provides a 'match' method with type
175 /// `OperationOrValue`.
176 template <typename T, typename OperationOrValue>
178  decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
179 
180 /// Statically switch to a Value matcher.
181 template <typename MatcherClass>
182 typename std::enable_if_t<
183  llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
184  Value>::value,
185  bool>
186 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
187  return matcher.match(op->getOperand(idx));
188 }
189 
190 /// Statically switch to an Operation matcher.
191 template <typename MatcherClass>
192 typename std::enable_if_t<
193  llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
194  Operation *>::value,
195  bool>
196 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
197  if (auto *defOp = op->getOperand(idx).getDefiningOp())
198  return matcher.match(defOp);
199  return false;
200 }
201 
202 /// Terminal matcher, always returns true.
204  bool match(Value op) const { return true; }
205 };
206 
207 /// Terminal matcher, always returns true.
210  AnyCapturedValueMatcher(Value *what) : what(what) {}
211  bool match(Value op) const {
212  *what = op;
213  return true;
214  }
215 };
216 
217 /// Binds to a specific value and matches it.
220  bool match(Value val) const { return val == value; }
222 };
223 
/// Calls `callback` once for every index in `Is`, in the order the indices
/// appear in the pack.
///
/// Each call receives two arguments: a `std::integral_constant<std::size_t, I>`
/// carrying the index, and `std::get<I>(tuple)`. Any value returned by the
/// callback is discarded.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // Expanding the pack inside a braced initializer list evaluates the calls
  // left to right. The leading 0 keeps the list non-empty for an empty pack.
  (void)std::initializer_list<int>{
      0,
      (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
       0)...};
}
232 
233 template <typename... Tys, typename CallbackT>
234 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
235  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
236  std::make_index_sequence<sizeof...(Tys)>{});
237 }
238 
239 /// RecursivePatternMatcher that composes.
240 template <typename OpType, typename... OperandMatchers>
242  RecursivePatternMatcher(OperandMatchers... matchers)
243  : operandMatchers(matchers...) {}
244  bool match(Operation *op) {
245  if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
246  return false;
247  bool res = true;
248  enumerate(operandMatchers, [&](size_t index, auto &matcher) {
249  res &= matchOperandOrValueAtIndex(op, index, matcher);
250  });
251  return res;
252  }
253  std::tuple<OperandMatchers...> operandMatchers;
254 };
255 
256 } // namespace detail
257 
258 /// Matches a constant foldable operation.
261 }
262 
263 /// Matches a value from a constant foldable operation and writes the value to
264 /// bind_value.
265 template <typename AttrT>
268 }
269 
270 /// Matches a constant scalar / vector splat / tensor splat float (both positive
271 /// and negative) zero.
273  return {[](const APFloat &value) { return value.isZero(); }};
274 }
275 
276 /// Matches a constant scalar / vector splat / tensor splat float positive zero.
278  return {[](const APFloat &value) { return value.isPosZero(); }};
279 }
280 
281 /// Matches a constant scalar / vector splat / tensor splat float negative zero.
283  return {[](const APFloat &value) { return value.isNegZero(); }};
284 }
285 
286 /// Matches a constant scalar / vector splat / tensor splat float ones.
288  return {[](const APFloat &value) {
289  return APFloat(value.getSemantics(), 1) == value;
290  }};
291 }
292 
293 /// Matches a constant scalar / vector splat / tensor splat float positive
294 /// infinity.
296  return {[](const APFloat &value) {
297  return !value.isNegative() && value.isInfinity();
298  }};
299 }
300 
301 /// Matches a constant scalar / vector splat / tensor splat float negative
302 /// infinity.
304  return {[](const APFloat &value) {
305  return value.isNegative() && value.isInfinity();
306  }};
307 }
308 
309 /// Matches a constant scalar / vector splat / tensor splat integer zero.
311  return {[](const APInt &value) { return 0 == value; }};
312 }
313 
314 /// Matches a constant scalar / vector splat / tensor splat integer that is any
315 /// non-zero value.
317  return {[](const APInt &value) { return 0 != value; }};
318 }
319 
320 /// Matches a constant scalar / vector splat / tensor splat integer one.
322  return {[](const APInt &value) { return 1 == value; }};
323 }
324 
325 /// Matches the given OpClass.
326 template <typename OpClass>
329 }
330 
331 /// Entry point for matching a pattern over a Value.
332 template <typename Pattern>
333 inline bool matchPattern(Value value, const Pattern &pattern) {
334  // TODO: handle other cases
335  if (auto *op = value.getDefiningOp())
336  return const_cast<Pattern &>(pattern).match(op);
337  return false;
338 }
339 
340 /// Entry point for matching a pattern over an Operation.
341 template <typename Pattern>
342 inline bool matchPattern(Operation *op, const Pattern &pattern) {
343  return const_cast<Pattern &>(pattern).match(op);
344 }
345 
346 /// Matches a constant holding a scalar/vector/tensor float (splat) and
347 /// writes the float value to bind_value.
349 m_ConstantFloat(FloatAttr::ValueType *bind_value) {
350  return detail::constant_float_op_binder(bind_value);
351 }
352 
353 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
354 /// writes the integer value to bind_value.
356 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
357  return detail::constant_int_op_binder(bind_value);
358 }
359 
360 template <typename OpType, typename... Matchers>
361 auto m_Op(Matchers... matchers) {
362  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
363 }
364 
365 namespace matchers {
366 inline auto m_Any() { return detail::AnyValueMatcher(); }
367 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
368 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
369 } // namespace matchers
370 
371 } // namespace mlir
372 
373 #endif // MLIR_IR_MATCHERS_H
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:303
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
Definition: Matchers.h:225
The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute ...
Definition: Matchers.h:66
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:356
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Binds to a specific value and matches it.
Definition: Matchers.h:218
Value getOperand(unsigned idx)
Definition: Operation.h:274
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:277
static bool isConstantLike(Operation *op)
Check to see if the specified operation is ConstantLike.
Definition: Matchers.h:53
unsigned getNumOperands()
Definition: Operation.h:270
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:40
std::enable_if_t< llvm::is_detected< detail::has_operation_or_value_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
Definition: Matchers.h:186
detail::constant_float_op_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition: Matchers.h:349
constant_int_op_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:137
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
auto m_Any(Value *val)
Definition: Matchers.h:367
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:282
constant_float_op_binder(FloatAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:100
static constexpr const bool value
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
The matcher that matches operations that have the ConstantLike trait.
Definition: Matchers.h:59
RecursivePatternMatcher that composes.
Definition: Matchers.h:241
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero...
Definition: Matchers.h:272
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::tuple< OperandMatchers... > operandMatchers
Definition: Matchers.h:253
bool match(Value op) const
Definition: Matchers.h:204
bool match(Operation *op)
Definition: Matchers.h:75
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:295
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:287
Attributes are known-constant values of operations.
Definition: Attributes.h:24
The matcher that matches a constant scalar / vector splat / tensor splat float operation and binds th...
Definition: Matchers.h:96
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:526
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:331
constant_op_binder()
Creates a matcher instance that doesn't bind if match succeeds.
Definition: Matchers.h:73
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:310
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:259
IntegerAttr::ValueType * bind_value
Definition: Matchers.h:134
decltype(std::declval< T >().match(std::declval< OperationOrValue >())) has_operation_or_value_matcher_t
Trait to check whether T provides a 'match' method with type OperationOrValue.
Definition: Matchers.h:178
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:321
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition: Matchers.h:327
The matcher that matches a certain kind of op.
Definition: Matchers.h:170
Terminal matcher, always returns true.
Definition: Matchers.h:203
Type getType() const
Return the type of this value.
Definition: Value.h:118
RecursivePatternMatcher(OperandMatchers... matchers)
Definition: Matchers.h:242
The matcher that matches a constant scalar / vector splat / tensor splat integer operation and binds ...
Definition: Matchers.h:133
U dyn_cast() const
Definition: Attributes.h:124
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:333
bool match(Operation *op)
Definition: Matchers.h:171
The matcher that matches a given target constant scalar / vector splat / tensor splat integer value t...
Definition: Matchers.h:159
The matcher that matches a given target constant scalar / vector splat / tensor splat float value tha...
Definition: Matchers.h:122
constant_op_binder(AttrT *bind_value)
Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds...
Definition: Matchers.h:71
bool match(Value val) const
Definition: Matchers.h:220
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
auto m_Val(Value v)
Definition: Matchers.h:368
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:496
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:328
bool match(Operation *op)
Definition: Matchers.h:60
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value...
Definition: Matchers.h:316
bool match(const Attribute &attr)
Definition: Matchers.h:42
This class provides the API for a sub-set of ops that are known to be constant-like.
FloatAttr::ValueType * bind_value
Definition: Matchers.h:97
The matcher that matches a certain kind of Attribute and binds the value inside the Attribute...
Definition: Matchers.h:36
Terminal matcher, always returns true.
Definition: Matchers.h:208