MLIR  19.0.0git
PDLPatternMatch.h.inc
Go to the documentation of this file.
1 //===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PDLPATTERNMATCH_H
10 #define MLIR_IR_PDLPATTERNMATCH_H
11 
12 #include "mlir/Config/mlir-config.h"
13 
14 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 
18 namespace mlir {
19 //===----------------------------------------------------------------------===//
20 // PDL Patterns
21 //===----------------------------------------------------------------------===//
22 
23 //===----------------------------------------------------------------------===//
24 // PDLValue
25 
26 /// Storage type of byte-code interpreter values. These are passed to constraint
27 /// functions as arguments.
28 class PDLValue {
29 public:
30  /// The underlying kind of a PDL value.
31  enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
32 
33  /// Construct a new PDL value.
34  PDLValue(const PDLValue &other) = default;
35  PDLValue(std::nullptr_t = nullptr) {}
36  PDLValue(Attribute value)
37  : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
38  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
39  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
40  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
41  PDLValue(Value value)
42  : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
43  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
44 
45  /// Returns true if the type of the held value is `T`.
46  template <typename T>
47  bool isa() const {
48  assert(value && "isa<> used on a null value");
49  return kind == getKindOf<T>();
50  }
51 
52  /// Attempt to dynamically cast this value to type `T`, returns null if this
53  /// value is not an instance of `T`.
54  template <typename T,
55  typename ResultT = std::conditional_t<
56  std::is_convertible<T, bool>::value, T, std::optional<T>>>
57  ResultT dyn_cast() const {
58  return isa<T>() ? castImpl<T>() : ResultT();
59  }
60 
61  /// Cast this value to type `T`, asserts if this value is not an instance of
62  /// `T`.
63  template <typename T>
64  T cast() const {
65  assert(isa<T>() && "expected value to be of type `T`");
66  return castImpl<T>();
67  }
68 
69  /// Get an opaque pointer to the value.
70  const void *getAsOpaquePointer() const { return value; }
71 
72  /// Return if this value is null or not.
73  explicit operator bool() const { return value; }
74 
75  /// Return the kind of this value.
76  Kind getKind() const { return kind; }
77 
78  /// Print this value to the provided output stream.
79  void print(raw_ostream &os) const;
80 
81  /// Print the specified value kind to an output stream.
82  static void print(raw_ostream &os, Kind kind);
83 
84 private:
85  /// Find the index of a given type in a range of other types.
86  template <typename...>
87  struct index_of_t;
88  template <typename T, typename... R>
89  struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
90  template <typename T, typename F, typename... R>
91  struct index_of_t<T, F, R...>
92  : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
93 
94  /// Return the kind used for the given T.
95  template <typename T>
96  static Kind getKindOf() {
97  return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
98  TypeRange, Value, ValueRange>::value);
99  }
100 
101  /// The internal implementation of `cast`, that returns the underlying value
102  /// as the given type `T`.
103  template <typename T>
104  std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
105  castImpl() const {
106  return T::getFromOpaquePointer(value);
107  }
108  template <typename T>
109  std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
110  castImpl() const {
111  return *reinterpret_cast<T *>(const_cast<void *>(value));
112  }
113  template <typename T>
114  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
115  return reinterpret_cast<T>(const_cast<void *>(value));
116  }
117 
118  /// The internal opaque representation of a PDLValue.
119  const void *value{nullptr};
120  /// The kind of the opaque value.
121  Kind kind{Kind::Attribute};
122 };
123 
124 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
125  value.print(os);
126  return os;
127 }
128 
129 inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
130  PDLValue::print(os, kind);
131  return os;
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // PDLResultList
136 
137 /// The class represents a list of PDL results, returned by a native rewrite
138 /// method. It provides the mechanism with which to pass PDLValues back to the
139 /// PDL bytecode.
140 class PDLResultList {
141 public:
142  /// Push a new Attribute value onto the result list.
143  void push_back(Attribute value) { results.push_back(value); }
144 
145  /// Push a new Operation onto the result list.
146  void push_back(Operation *value) { results.push_back(value); }
147 
148  /// Push a new Type onto the result list.
149  void push_back(Type value) { results.push_back(value); }
150 
151  /// Push a new TypeRange onto the result list.
152  void push_back(TypeRange value) {
153  // The lifetime of a TypeRange can't be guaranteed, so we'll need to
154  // allocate a storage for it.
155  llvm::OwningArrayRef<Type> storage(value.size());
156  llvm::copy(value, storage.begin());
157  allocatedTypeRanges.emplace_back(std::move(storage));
158  typeRanges.push_back(allocatedTypeRanges.back());
159  results.push_back(&typeRanges.back());
160  }
161  void push_back(ValueTypeRange<OperandRange> value) {
162  typeRanges.push_back(value);
163  results.push_back(&typeRanges.back());
164  }
165  void push_back(ValueTypeRange<ResultRange> value) {
166  typeRanges.push_back(value);
167  results.push_back(&typeRanges.back());
168  }
169 
170  /// Push a new Value onto the result list.
171  void push_back(Value value) { results.push_back(value); }
172 
173  /// Push a new ValueRange onto the result list.
174  void push_back(ValueRange value) {
175  // The lifetime of a ValueRange can't be guaranteed, so we'll need to
176  // allocate a storage for it.
177  llvm::OwningArrayRef<Value> storage(value.size());
178  llvm::copy(value, storage.begin());
179  allocatedValueRanges.emplace_back(std::move(storage));
180  valueRanges.push_back(allocatedValueRanges.back());
181  results.push_back(&valueRanges.back());
182  }
183  void push_back(OperandRange value) {
184  valueRanges.push_back(value);
185  results.push_back(&valueRanges.back());
186  }
187  void push_back(ResultRange value) {
188  valueRanges.push_back(value);
189  results.push_back(&valueRanges.back());
190  }
191 
192 protected:
193  /// Create a new result list with the expected number of results.
194  PDLResultList(unsigned maxNumResults) {
195  // For now just reserve enough space for all of the results. We could do
196  // separate counts per range type, but it isn't really worth it unless there
197  // are a "large" number of results.
198  typeRanges.reserve(maxNumResults);
199  valueRanges.reserve(maxNumResults);
200  }
201 
202  /// The PDL results held by this list.
203  SmallVector<PDLValue> results;
204  /// Memory used to store ranges held by the list.
205  SmallVector<TypeRange> typeRanges;
206  SmallVector<ValueRange> valueRanges;
207  /// Memory allocated to store ranges in the result list whose lifetime was
208  /// generated in the native function.
209  SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
210  SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
211 };
212 
213 //===----------------------------------------------------------------------===//
214 // PDLPatternConfig
215 
216 /// An individual configuration for a pattern, which can be accessed by native
217 /// functions via the PDLPatternConfigSet. This allows for injecting additional
218 /// configuration into PDL patterns that is specific to certain compilation
219 /// flows.
220 class PDLPatternConfig {
221 public:
222  virtual ~PDLPatternConfig() = default;
223 
224  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
225  /// pattern. These can be used to setup any specific state necessary for the
226  /// rewrite.
227  virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
228  virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
229 
230  /// Return the TypeID that represents this configuration.
231  TypeID getTypeID() const { return id; }
232 
233 protected:
234  PDLPatternConfig(TypeID id) : id(id) {}
235 
236 private:
237  TypeID id;
238 };
239 
240 /// This class provides a base class for users implementing a type of pattern
241 /// configuration.
242 template <typename T>
243 class PDLPatternConfigBase : public PDLPatternConfig {
244 public:
245  /// Support LLVM style casting.
246  static bool classof(const PDLPatternConfig *config) {
247  return config->getTypeID() == getConfigID();
248  }
249 
250  /// Return the type id used for this configuration.
251  static TypeID getConfigID() { return TypeID::get<T>(); }
252 
253 protected:
254  PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
255 };
256 
257 /// This class contains a set of configurations for a specific pattern.
258 /// Configurations are uniqued by TypeID, meaning that only one configuration of
259 /// each type is allowed.
260 class PDLPatternConfigSet {
261 public:
262  PDLPatternConfigSet() = default;
263 
264  /// Construct a set with the given configurations.
265  template <typename... ConfigsT>
266  PDLPatternConfigSet(ConfigsT &&...configs) {
267  (addConfig(std::forward<ConfigsT>(configs)), ...);
268  }
269 
270  /// Get the configuration defined by the given type. Asserts that the
271  /// configuration of the provided type exists.
272  template <typename T>
273  const T &get() const {
274  const T *config = tryGet<T>();
275  assert(config && "configuration not found");
276  return *config;
277  }
278 
279  /// Get the configuration defined by the given type, returns nullptr if the
280  /// configuration does not exist.
281  template <typename T>
282  const T *tryGet() const {
283  for (const auto &configIt : configs)
284  if (const T *config = dyn_cast<T>(configIt.get()))
285  return config;
286  return nullptr;
287  }
288 
289  /// Notify the configurations within this set at the beginning or end of a
290  /// rewrite of a matched pattern.
291  void notifyRewriteBegin(PatternRewriter &rewriter) {
292  for (const auto &config : configs)
293  config->notifyRewriteBegin(rewriter);
294  }
295  void notifyRewriteEnd(PatternRewriter &rewriter) {
296  for (const auto &config : configs)
297  config->notifyRewriteEnd(rewriter);
298  }
299 
300 protected:
301  /// Add a configuration to the set.
302  template <typename T>
303  void addConfig(T &&config) {
304  assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
305  configs.emplace_back(
306  std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
307  }
308 
309  /// The set of configurations for this pattern. This uses a vector instead of
310  /// a map with the expectation that the number of configurations per set is
311  /// small (<= 1).
312  SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
313 };
314 
315 //===----------------------------------------------------------------------===//
316 // PDLPatternModule
317 
318 /// A generic PDL pattern constraint function. This function applies a
319 /// constraint to a given set of opaque PDLValue entities. Returns success if
320 /// the constraint successfully held, failure otherwise.
321 using PDLConstraintFunction = std::function<LogicalResult(
322  PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
323 
324 /// A native PDL rewrite function. This function performs a rewrite on the
325 /// given set of values. Any results from this rewrite that should be passed
326 /// back to PDL should be added to the provided result list. This method is only
327 /// invoked when the corresponding match was successful. Returns failure if an
328 /// invariant of the rewrite was broken (certain rewriters may recover from
329 /// partial pattern application).
330 using PDLRewriteFunction = std::function<LogicalResult(
331  PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
332 
333 namespace detail {
334 namespace pdl_function_builder {
335 /// A utility variable that always resolves to false. This is useful for static
336 /// asserts that are always false, but only should fire in certain templated
337 /// constructs. For example, if a templated function should never be called, the
338 /// function could be defined as:
339 ///
340 /// template <typename T>
341 /// void foo() {
342 /// static_assert(always_false<T>, "This function should never be called");
343 /// }
344 ///
345 template <class... T>
346 constexpr bool always_false = false;
347 
348 //===----------------------------------------------------------------------===//
349 // PDL Function Builder: Type Processing
350 //===----------------------------------------------------------------------===//
351 
/// Customization point describing how a C++ type is handled as a PDL argument
/// and/or as a PDL result. Specializing it lets constraint and rewrite
/// functions use rich C++ types directly, without hand-written glue.
///
/// To be usable as an argument, a specialization provides:
///
///   static LogicalResult verifyAsArg(
///       function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///       size_t argIdx);
///     - Checks that `pdlValue` can be used as a value of `ValueT`, reporting
///       problems through `errorFn`.
///
///   static ValueT processAsArg(PDLValue pdlValue);
///     - Converts `pdlValue` into a value of `ValueT`.
///
/// To be usable as a result, a specialization provides:
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               const ValueT &value);
///     - Packages `value` into an appropriate form and appends it to
///       `results`.
///
/// When `ValueT` is layered on another processable type, deriving the
/// specialization from `ProcessPDLValueBasedOn` removes most of the
/// boilerplate. The second parameter exists only to enable SFINAE-based
/// partial specializations.
template <typename ValueT, typename EnableT = void>
struct ProcessPDLValue;
383 
384 /// This struct provides a simplified model for processing types that are based
385 /// on another type, e.g. APInt is based on the handling for IntegerAttr. This
386 /// allows for building the necessary processing functions on top of the base
387 /// value instead of a PDLValue. Derived users should implement the following
388 /// (which subsume the ProcessPDLValue variants):
389 ///
390 /// static LogicalResult verifyAsArg(
391 /// function_ref<LogicalResult(const Twine &)> errorFn,
392 /// const BaseT &baseValue, size_t argIdx);
393 ///
394 /// * This method verifies that the given PDLValue is valid for use as a
395 /// value of `T`.
396 ///
397 /// static T processAsArg(BaseT baseValue);
398 ///
399 /// * This method processes the given base value as a value of `T`.
400 ///
401 template <typename T, typename BaseT>
402 struct ProcessPDLValueBasedOn {
403  static LogicalResult
404  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
405  PDLValue pdlValue, size_t argIdx) {
406  // Verify the base class before continuing.
407  if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
408  return failure();
409  return ProcessPDLValue<T>::verifyAsArg(
410  errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
411  }
412  static T processAsArg(PDLValue pdlValue) {
413  return ProcessPDLValue<T>::processAsArg(
414  ProcessPDLValue<BaseT>::processAsArg(pdlValue));
415  }
416 
417  /// Explicitly add the expected parent API to ensure the parent class
418  /// implements the necessary API (and doesn't implicitly inherit it from
419  /// somewhere else).
420  static LogicalResult
421  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
422  size_t argIdx) {
423  return success();
424  }
425  static T processAsArg(BaseT baseValue);
426 };
427 
428 /// This struct provides a simplified model for processing types that have
429 /// "builtin" PDLValue support:
430 /// * Attribute, Operation *, Type, TypeRange, ValueRange
431 template <typename T>
432 struct ProcessBuiltinPDLValue {
433  static LogicalResult
434  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
435  PDLValue pdlValue, size_t argIdx) {
436  if (pdlValue)
437  return success();
438  return errorFn("expected a non-null value for argument " + Twine(argIdx) +
439  " of type: " + llvm::getTypeName<T>());
440  }
441 
442  static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
443  static void processAsResult(PatternRewriter &, PDLResultList &results,
444  T value) {
445  results.push_back(value);
446  }
447 };
448 
449 /// This struct provides a simplified model for processing types that inherit
450 /// from builtin PDLValue types. For example, derived attributes like
451 /// IntegerAttr, derived types like IntegerType, derived operations like
452 /// ModuleOp, Interfaces, etc.
453 template <typename T, typename BaseT>
454 struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
455  static LogicalResult
456  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
457  BaseT baseValue, size_t argIdx) {
458  return TypeSwitch<BaseT, LogicalResult>(baseValue)
459  .Case([&](T) { return success(); })
460  .Default([&](BaseT) {
461  return errorFn("expected argument " + Twine(argIdx) +
462  " to be of type: " + llvm::getTypeName<T>());
463  });
464  }
465  using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
466 
467  static T processAsArg(BaseT baseValue) {
468  return baseValue.template cast<T>();
469  }
470  using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
471 
472  static void processAsResult(PatternRewriter &, PDLResultList &results,
473  T value) {
474  results.push_back(value);
475  }
476 };
477 
478 //===----------------------------------------------------------------------===//
479 // Attribute
480 
481 template <>
482 struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
483 template <typename T>
484 struct ProcessPDLValue<T,
485  std::enable_if_t<std::is_base_of<Attribute, T>::value>>
486  : public ProcessDerivedPDLValue<T, Attribute> {};
487 
488 /// Handling for various Attribute value types.
489 template <>
490 struct ProcessPDLValue<StringRef>
491  : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
492  static StringRef processAsArg(StringAttr value) { return value.getValue(); }
493  using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
494 
495  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
496  StringRef value) {
497  results.push_back(rewriter.getStringAttr(value));
498  }
499 };
500 template <>
501 struct ProcessPDLValue<std::string>
502  : public ProcessPDLValueBasedOn<std::string, StringAttr> {
503  template <typename T>
504  static std::string processAsArg(T value) {
505  static_assert(always_false<T>,
506  "`std::string` arguments require a string copy, use "
507  "`StringRef` for string-like arguments instead");
508  return {};
509  }
510  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
511  StringRef value) {
512  results.push_back(rewriter.getStringAttr(value));
513  }
514 };
515 
516 //===----------------------------------------------------------------------===//
517 // Operation
518 
519 template <>
520 struct ProcessPDLValue<Operation *>
521  : public ProcessBuiltinPDLValue<Operation *> {};
522 template <typename T>
523 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
524  : public ProcessDerivedPDLValue<T, Operation *> {
525  static T processAsArg(Operation *value) { return cast<T>(value); }
526 };
527 
528 //===----------------------------------------------------------------------===//
529 // Type
530 
531 template <>
532 struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
533 template <typename T>
534 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
535  : public ProcessDerivedPDLValue<T, Type> {};
536 
537 //===----------------------------------------------------------------------===//
538 // TypeRange
539 
540 template <>
541 struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
542 template <>
543 struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
544  static void processAsResult(PatternRewriter &, PDLResultList &results,
545  ValueTypeRange<OperandRange> types) {
546  results.push_back(types);
547  }
548 };
549 template <>
550 struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
551  static void processAsResult(PatternRewriter &, PDLResultList &results,
552  ValueTypeRange<ResultRange> types) {
553  results.push_back(types);
554  }
555 };
556 template <unsigned N>
557 struct ProcessPDLValue<SmallVector<Type, N>> {
558  static void processAsResult(PatternRewriter &, PDLResultList &results,
559  SmallVector<Type, N> values) {
560  results.push_back(TypeRange(values));
561  }
562 };
563 
564 //===----------------------------------------------------------------------===//
565 // Value
566 
567 template <>
568 struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
569 
570 //===----------------------------------------------------------------------===//
571 // ValueRange
572 
573 template <>
574 struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
575 };
576 template <>
577 struct ProcessPDLValue<OperandRange> {
578  static void processAsResult(PatternRewriter &, PDLResultList &results,
579  OperandRange values) {
580  results.push_back(values);
581  }
582 };
583 template <>
584 struct ProcessPDLValue<ResultRange> {
585  static void processAsResult(PatternRewriter &, PDLResultList &results,
586  ResultRange values) {
587  results.push_back(values);
588  }
589 };
590 template <unsigned N>
591 struct ProcessPDLValue<SmallVector<Value, N>> {
592  static void processAsResult(PatternRewriter &, PDLResultList &results,
593  SmallVector<Value, N> values) {
594  results.push_back(ValueRange(values));
595  }
596 };
597 
598 //===----------------------------------------------------------------------===//
599 // PDL Function Builder: Argument Handling
600 //===----------------------------------------------------------------------===//
601 
602 /// Validate the given PDLValues match the constraints defined by the argument
603 /// types of the given function. In the case of failure, a match failure
604 /// diagnostic is emitted.
605 /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
606 /// does not currently preserve Constraint application ordering.
607 template <typename PDLFnT, std::size_t... I>
608 LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
609  std::index_sequence<I...>) {
610  using FnTraitsT = llvm::function_traits<PDLFnT>;
611 
612  auto errorFn = [&](const Twine &msg) {
613  return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
614  };
615  return success(
616  (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
617  verifyAsArg(errorFn, values[I], I)) &&
618  ...));
619 }
620 
621 /// Assert that the given PDLValues match the constraints defined by the
622 /// arguments of the given function. In the case of failure, a fatal error
623 /// is emitted.
624 template <typename PDLFnT, std::size_t... I>
625 void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
626  std::index_sequence<I...>) {
627  // We only want to do verification in debug builds, same as with `assert`.
628 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
629  using FnTraitsT = llvm::function_traits<PDLFnT>;
630  auto errorFn = [&](const Twine &msg) -> LogicalResult {
631  llvm::report_fatal_error(msg);
632  };
633  (void)errorFn;
634  assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
635  verifyAsArg(errorFn, values[I], I)) &&
636  ...));
637 #endif
638  (void)values;
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // PDL Function Builder: Results Handling
643 //===----------------------------------------------------------------------===//
644 
645 /// Store a single result within the result list.
646 template <typename T>
647 static LogicalResult processResults(PatternRewriter &rewriter,
648  PDLResultList &results, T &&value) {
649  ProcessPDLValue<T>::processAsResult(rewriter, results,
650  std::forward<T>(value));
651  return success();
652 }
653 
654 /// Store a std::pair<> as individual results within the result list.
655 template <typename T1, typename T2>
656 static LogicalResult processResults(PatternRewriter &rewriter,
657  PDLResultList &results,
658  std::pair<T1, T2> &&pair) {
659  if (failed(processResults(rewriter, results, std::move(pair.first))) ||
660  failed(processResults(rewriter, results, std::move(pair.second))))
661  return failure();
662  return success();
663 }
664 
665 /// Store a std::tuple<> as individual results within the result list.
666 template <typename... Ts>
667 static LogicalResult processResults(PatternRewriter &rewriter,
668  PDLResultList &results,
669  std::tuple<Ts...> &&tuple) {
670  auto applyFn = [&](auto &&...args) {
671  return (succeeded(processResults(rewriter, results, std::move(args))) &&
672  ...);
673  };
674  return success(std::apply(applyFn, std::move(tuple)));
675 }
676 
677 /// Handle LogicalResult propagation.
678 inline LogicalResult processResults(PatternRewriter &rewriter,
679  PDLResultList &results,
680  LogicalResult &&result) {
681  return result;
682 }
683 template <typename T>
684 static LogicalResult processResults(PatternRewriter &rewriter,
685  PDLResultList &results,
686  FailureOr<T> &&result) {
687  if (failed(result))
688  return failure();
689  return processResults(rewriter, results, std::move(*result));
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // PDL Constraint Builder
694 //===----------------------------------------------------------------------===//
695 
696 /// Process the arguments of a native constraint and invoke it.
697 template <typename PDLFnT, std::size_t... I,
698  typename FnTraitsT = llvm::function_traits<PDLFnT>>
699 typename FnTraitsT::result_t
700 processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
701  ArrayRef<PDLValue> values,
702  std::index_sequence<I...>) {
703  return fn(
704  rewriter,
705  (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
706  values[I]))...);
707 }
708 
709 /// Build a constraint function from the given function `ConstraintFnT`. This
710 /// allows for enabling the user to define simpler, more direct constraint
711 /// functions without needing to handle the low-level PDL goop.
712 ///
713 /// If the constraint function is already in the correct form, we just forward
714 /// it directly.
715 template <typename ConstraintFnT>
716 std::enable_if_t<
717  std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
718  PDLConstraintFunction>
719 buildConstraintFn(ConstraintFnT &&constraintFn) {
720  return std::forward<ConstraintFnT>(constraintFn);
721 }
722 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
723 /// we desire.
724 template <typename ConstraintFnT>
725 std::enable_if_t<
726  !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
727  PDLConstraintFunction>
728 buildConstraintFn(ConstraintFnT &&constraintFn) {
729  return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
730  PatternRewriter &rewriter, PDLResultList &,
731  ArrayRef<PDLValue> values) -> LogicalResult {
732  auto argIndices = std::make_index_sequence<
733  llvm::function_traits<ConstraintFnT>::num_args - 1>();
734  if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
735  return failure();
736  return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
737  argIndices);
738  };
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // PDL Rewrite Builder
743 //===----------------------------------------------------------------------===//
744 
745 /// Process the arguments of a native rewrite and invoke it.
746 /// This overload handles the case of no return values.
747 template <typename PDLFnT, std::size_t... I,
748  typename FnTraitsT = llvm::function_traits<PDLFnT>>
749 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
750  LogicalResult>
751 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
752  PDLResultList &, ArrayRef<PDLValue> values,
753  std::index_sequence<I...>) {
754  fn(rewriter,
755  (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
756  values[I]))...);
757  return success();
758 }
759 /// This overload handles the case of return values, which need to be packaged
760 /// into the result list.
761 template <typename PDLFnT, std::size_t... I,
762  typename FnTraitsT = llvm::function_traits<PDLFnT>>
763 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
764  LogicalResult>
765 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
766  PDLResultList &results, ArrayRef<PDLValue> values,
767  std::index_sequence<I...>) {
768  return processResults(
769  rewriter, results,
770  fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
771  processAsArg(values[I]))...));
772  (void)values;
773 }
774 
775 /// Build a rewrite function from the given function `RewriteFnT`. This
776 /// allows for enabling the user to define simpler, more direct rewrite
777 /// functions without needing to handle the low-level PDL goop.
778 ///
779 /// If the rewrite function is already in the correct form, we just forward
780 /// it directly.
781 template <typename RewriteFnT>
782 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
783  PDLRewriteFunction>
784 buildRewriteFn(RewriteFnT &&rewriteFn) {
785  return std::forward<RewriteFnT>(rewriteFn);
786 }
787 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
788 /// we desire.
789 template <typename RewriteFnT>
790 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
791  PDLRewriteFunction>
792 buildRewriteFn(RewriteFnT &&rewriteFn) {
793  return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
794  PatternRewriter &rewriter, PDLResultList &results,
795  ArrayRef<PDLValue> values) {
796  auto argIndices =
797  std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
798  1>();
799  assertArgs<RewriteFnT>(rewriter, values, argIndices);
800  return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
801  argIndices);
802  };
803 }
804 
805 } // namespace pdl_function_builder
806 } // namespace detail
807 
808 //===----------------------------------------------------------------------===//
809 // PDLPatternModule
810 
811 /// This class contains all of the necessary data for a set of PDL patterns, or
812 /// pattern rewrites specified in the form of the PDL dialect. This PDL module
813 /// contained by this pattern may contain any number of `pdl.pattern`
814 /// operations.
815 class PDLPatternModule {
816 public:
817  PDLPatternModule() = default;
818 
819  /// Construct a PDL pattern with the given module and configurations.
820  PDLPatternModule(OwningOpRef<ModuleOp> module)
821  : pdlModule(std::move(module)) {}
822  template <typename... ConfigsT>
823  PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
824  : PDLPatternModule(std::move(module)) {
825  auto configSet = std::make_unique<PDLPatternConfigSet>(
826  std::forward<ConfigsT>(patternConfigs)...);
827  attachConfigToPatterns(*pdlModule, *configSet);
828  configs.emplace_back(std::move(configSet));
829  }
830 
831  /// Merge the state in `other` into this pattern module.
832  void mergeIn(PDLPatternModule &&other);
833 
834  /// Return the internal PDL module of this pattern.
835  ModuleOp getModule() { return pdlModule.get(); }
836 
837  /// Return the MLIR context of this pattern.
838  MLIRContext *getContext() { return getModule()->getContext(); }
839 
840  //===--------------------------------------------------------------------===//
841  // Function Registry
842 
843  /// Register a constraint function with PDL. A constraint function may be
844  /// specified in one of two ways:
845  ///
846  /// * `LogicalResult (PatternRewriter &,
847  /// PDLResultList &,
848  /// ArrayRef<PDLValue>)`
849  ///
850  /// In this overload the arguments of the constraint function are passed via
851  /// the low-level PDLValue form, and the results are manually appended to
852  /// the given result list.
853  ///
854  /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
855  ///
856  /// In this form the arguments of the constraint function are passed via the
857  /// expected high level C++ type. In this form, the framework will
858  /// automatically unwrap PDLValues and convert them to the expected ValueTs.
859  /// For example, if the constraint function accepts a `Operation *`, the
860  /// framework will automatically cast the input PDLValue. In the case of a
861  /// `StringRef`, the framework will automatically unwrap the argument as a
862  /// StringAttr and pass the underlying string value. To see the full list of
863  /// supported types, or to see how to add handling for custom types, view
864  /// the definition of `ProcessPDLValue` above.
865  void registerConstraintFunction(StringRef name,
866  PDLConstraintFunction constraintFn);
867  template <typename ConstraintFnT>
868  void registerConstraintFunction(StringRef name,
869  ConstraintFnT &&constraintFn) {
870  registerConstraintFunction(name,
871  detail::pdl_function_builder::buildConstraintFn(
872  std::forward<ConstraintFnT>(constraintFn)));
873  }
874 
875  /// Register a rewrite function with PDL. A rewrite function may be specified
876  /// in one of two ways:
877  ///
878  /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
879  ///
880  /// In this overload the arguments of the constraint function are passed via
881  /// the low-level PDLValue form, and the results are manually appended to
882  /// the given result list.
883  ///
884  /// * `ResultT (PatternRewriter &, ValueTs... values)`
885  ///
886  /// In this form the arguments and result of the rewrite function are passed
887  /// via the expected high level C++ type. In this form, the framework will
888  /// automatically unwrap the PDLValues arguments and convert them to the
889  /// expected ValueTs. It will also automatically handle the processing and
890  /// packaging of the result value to the result list. For example, if the
891  /// rewrite function takes a `Operation *`, the framework will automatically
892  /// cast the input PDLValue. In the case of a `StringRef`, the framework
893  /// will automatically unwrap the argument as a StringAttr and pass the
894  /// underlying string value. In the reverse case, if the rewrite returns a
895  /// StringRef or std::string, it will automatically package this as a
896  /// StringAttr and append it to the result list. To see the full list of
897  /// supported types, or to see how to add handling for custom types, view
898  /// the definition of `ProcessPDLValue` above.
899  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
900  template <typename RewriteFnT>
901  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
902  registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
903  std::forward<RewriteFnT>(rewriteFn)));
904  }
905 
906  /// Return the set of the registered constraint functions.
907  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
908  return constraintFunctions;
909  }
910  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
911  return constraintFunctions;
912  }
913  /// Return the set of the registered rewrite functions.
914  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
915  return rewriteFunctions;
916  }
917  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
918  return rewriteFunctions;
919  }
920 
921  /// Return the set of the registered pattern configs.
922  SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
923  return std::move(configs);
924  }
925  DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
926  return std::move(configMap);
927  }
928 
929  /// Clear out the patterns and functions within this module.
930  void clear() {
931  pdlModule = nullptr;
932  constraintFunctions.clear();
933  rewriteFunctions.clear();
934  }
935 
936 private:
937  /// Attach the given pattern config set to the patterns defined within the
938  /// given module.
939  void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
940 
941  /// The module containing the `pdl.pattern` operations.
942  OwningOpRef<ModuleOp> pdlModule;
943 
944  /// The set of configuration sets referenced by patterns within `pdlModule`.
945  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
946  DenseMap<Operation *, PDLPatternConfigSet *> configMap;
947 
948  /// The external functions referenced from within the PDL module.
949  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
950  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
951 };
952 } // namespace mlir
953 
954 #else
955 
956 namespace mlir {
957 // Stubs for when PDL in pattern rewrites is not enabled.
958 
959 class PDLValue {
960 public:
961  template <typename T>
962  T dyn_cast() const {
963  return nullptr;
964  }
965 };
966 class PDLResultList {};
967 using PDLConstraintFunction = std::function<LogicalResult(
968  PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
969 using PDLRewriteFunction = std::function<LogicalResult(
970  PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
971 
972 class PDLPatternModule {
973 public:
974  PDLPatternModule() = default;
975 
976  PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}
977  MLIRContext *getContext() {
978  llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
979  }
980  void mergeIn(PDLPatternModule &&other) {}
981  void clear() {}
982  template <typename ConstraintFnT>
983  void registerConstraintFunction(StringRef name,
984  ConstraintFnT &&constraintFn) {}
985  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
986  template <typename RewriteFnT>
987  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
988  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
989  return constraintFunctions;
990  }
991 
992 private:
993  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
994 };
995 
996 } // namespace mlir
997 #endif
998 
999 #endif // MLIR_IR_PDLPATTERNMATCH_H
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Type
An inlay hint for a type annotation.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78