MLIR  14.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 
17 namespace mlir {
18 
19 class PatternRewriter;
20 
21 //===----------------------------------------------------------------------===//
22 // PatternBenefit class
23 //===----------------------------------------------------------------------===//
24 
25 /// This class represents the benefit of a pattern match in a unitless scheme
26 /// that ranges from 0 (very little benefit) to 65K. The most common unit to
27 /// use here is the "number of operations matched" by the pattern.
28 ///
29 /// This also has a sentinel representation that can be used for patterns that
30 /// fail to match.
31 ///
33  enum { ImpossibleToMatchSentinel = 65535 };
34 
35 public:
36  PatternBenefit() = default;
37  PatternBenefit(unsigned benefit);
38  PatternBenefit(const PatternBenefit &) = default;
39  PatternBenefit &operator=(const PatternBenefit &) = default;
40 
42  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
43 
44  /// If the corresponding pattern can match, return its benefit. If the
45  /// corresponding pattern isImpossibleToMatch() then this aborts.
46  unsigned short getBenefit() const;
47 
48  bool operator==(const PatternBenefit &rhs) const {
49  return representation == rhs.representation;
50  }
51  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
52  bool operator<(const PatternBenefit &rhs) const {
53  return representation < rhs.representation;
54  }
55  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
56  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
57  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
58 
59 private:
60  unsigned short representation{ImpossibleToMatchSentinel};
61 };
62 
63 //===----------------------------------------------------------------------===//
64 // Pattern
65 //===----------------------------------------------------------------------===//
66 
67 /// This class contains all of the data related to a pattern, but does not
68 /// contain any methods or logic for the actual matching. This class is solely
69 /// used to interface with the metadata of a pattern, such as the benefit or
70 /// root operation.
71 class Pattern {
72  /// This enum represents the kind of value used to select the root operations
73  /// that match this pattern.
74  enum class RootKind {
75  /// The pattern root matches "any" operation.
76  Any,
77  /// The pattern root is matched using a concrete operation name.
79  /// The pattern root is matched using an interface ID.
80  InterfaceID,
81  /// The pattern root is matched using a trait ID.
82  TraitID
83  };
84 
85 public:
86  /// Return a list of operations that may be generated when rewriting an
87  /// operation instance with this pattern.
88  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
89 
90  /// Return the root node that this pattern matches. Patterns that can match
91  /// multiple root types return None.
93  if (rootKind == RootKind::OperationName)
94  return OperationName::getFromOpaquePointer(rootValue);
95  return llvm::None;
96  }
97 
98  /// Return the interface ID used to match the root operation of this pattern.
99  /// If the pattern does not use an interface ID for deciding the root match,
100  /// this returns None.
102  if (rootKind == RootKind::InterfaceID)
103  return TypeID::getFromOpaquePointer(rootValue);
104  return llvm::None;
105  }
106 
107  /// Return the trait ID used to match the root operation of this pattern.
108  /// If the pattern does not use a trait ID for deciding the root match, this
109  /// returns None.
111  if (rootKind == RootKind::TraitID)
112  return TypeID::getFromOpaquePointer(rootValue);
113  return llvm::None;
114  }
115 
116  /// Return the benefit (the inverse of "cost") of matching this pattern. The
117  /// benefit of a Pattern is always static - rewrites that may have dynamic
118  /// benefit can be instantiated multiple times (different Pattern instances)
119  /// for each benefit that they may return, and be guarded by different match
120  /// condition predicates.
121  PatternBenefit getBenefit() const { return benefit; }
122 
123  /// Returns true if this pattern is known to result in recursive application,
124  /// i.e. this pattern may generate IR that also matches this pattern, but is
125  /// known to bound the recursion. This signals to a rewrite driver that it is
126  /// safe to apply this pattern recursively to generated IR.
128  return contextAndHasBoundedRecursion.getInt();
129  }
130 
131  /// Return the MLIRContext used to create this pattern.
133  return contextAndHasBoundedRecursion.getPointer();
134  }
135 
136  /// Return a readable name for this pattern. This name should only be used for
137  /// debugging purposes, and may be empty.
138  StringRef getDebugName() const { return debugName; }
139 
140  /// Set the human readable debug name used for this pattern. This name will
141  /// only be used for debugging purposes.
142  void setDebugName(StringRef name) { debugName = name; }
143 
144  /// Return the set of debug labels attached to this pattern.
145  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
146 
147  /// Add the provided debug labels to this pattern.
149  debugLabels.append(labels.begin(), labels.end());
150  }
151  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
152 
153 protected:
154  /// This class acts as a special tag that makes the desire to match "any"
155  /// operation type explicit. This helps to avoid unnecessary usages of this
156  /// feature, and ensures that the user is making a conscious decision.
157  struct MatchAnyOpTypeTag {};
158  /// This class acts as a special tag that makes the desire to match any
159  /// operation that implements a given interface explicit. This helps to avoid
160  /// unnecessary usages of this feature, and ensures that the user is making a
161  /// conscious decision.
163  /// This class acts as a special tag that makes the desire to match any
164  /// operation that implements a given trait explicit. This helps to avoid
165  /// unnecessary usages of this feature, and ensures that the user is making a
166  /// conscious decision.
168 
169  /// Construct a pattern with a certain benefit that matches the operation
170  /// with the given root name.
171  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
172  ArrayRef<StringRef> generatedNames = {});
173  /// Construct a pattern that may match any operation type. `generatedNames`
174  /// contains the names of operations that may be generated during a successful
175  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
176  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
177  /// always be supplied here.
178  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
179  ArrayRef<StringRef> generatedNames = {});
180  /// Construct a pattern that may match any operation that implements the
181  /// interface defined by the provided `interfaceID`. `generatedNames` contains
182  /// the names of operations that may be generated during a successful rewrite.
183  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
184  /// interface" behavior is what the user actually desired,
185  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
186  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
187  PatternBenefit benefit, MLIRContext *context,
188  ArrayRef<StringRef> generatedNames = {});
189  /// Construct a pattern that may match any operation that implements the
190  /// trait defined by the provided `traitID`. `generatedNames` contains the
191  /// names of operations that may be generated during a successful rewrite.
192  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
193  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
194  /// always be supplied here.
195  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
196  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
197 
198  /// Set the flag detailing if this pattern has bounded rewrite recursion or
199  /// not.
200  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
201  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
202  }
203 
204 private:
205  Pattern(const void *rootValue, RootKind rootKind,
206  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
207  MLIRContext *context);
208 
209  /// The value used to match the root operation of the pattern.
210  const void *rootValue;
211  RootKind rootKind;
212 
213  /// The expected benefit of matching this pattern.
214  const PatternBenefit benefit;
215 
216  /// The context this pattern was created from, and a boolean flag indicating
217  /// whether this pattern has bounded recursion or not.
218  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
219 
220  /// A list of the potential operations that may be generated when rewriting
221  /// an op with this pattern.
222  SmallVector<OperationName, 2> generatedOps;
223 
224  /// A readable name for this pattern. May be empty.
225  StringRef debugName;
226 
227  /// The set of debug labels attached to this pattern.
228  SmallVector<StringRef, 0> debugLabels;
229 };
230 
231 //===----------------------------------------------------------------------===//
232 // RewritePattern
233 //===----------------------------------------------------------------------===//
234 
235 /// RewritePattern is the common base class for all DAG to DAG replacements.
236 /// There are two possible usages of this class:
237 /// * Multi-step RewritePattern with "match" and "rewrite"
238 /// - By overriding the "match" and "rewrite" functions, the user can
239 /// separate the concerns of matching and rewriting.
240 /// * Single-step RewritePattern with "matchAndRewrite"
241 /// - By overriding the "matchAndRewrite" function, the user can perform
242 /// the rewrite in the same call as the match.
243 ///
244 class RewritePattern : public Pattern {
245 public:
246  virtual ~RewritePattern() = default;
247 
248  /// Rewrite the IR rooted at the specified operation with the result of
249  /// this pattern, generating any new operations with the specified
250  /// builder. If an unexpected error is encountered (an internal
251  /// compiler error), it is emitted through the normal MLIR diagnostic
252  /// hooks and the IR is left in a valid state.
253  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
254 
255  /// Attempt to match against code rooted at the specified operation,
256  /// which is the same operation code as getRootKind().
257  virtual LogicalResult match(Operation *op) const;
258 
259  /// Attempt to match against code rooted at the specified operation,
260  /// which is the same operation code as getRootKind(). If successful, this
261  /// function will automatically perform the rewrite.
263  PatternRewriter &rewriter) const {
264  if (succeeded(match(op))) {
265  rewrite(op, rewriter);
266  return success();
267  }
268  return failure();
269  }
270 
271  /// This method provides a convenient interface for creating and initializing
272  /// derived rewrite patterns of the given type `T`.
273  template <typename T, typename... Args>
274  static std::unique_ptr<T> create(Args &&... args) {
275  std::unique_ptr<T> pattern =
276  std::make_unique<T>(std::forward<Args>(args)...);
277  initializePattern<T>(*pattern);
278 
279  // Set a default debug name if one wasn't provided.
280  if (pattern->getDebugName().empty())
281  pattern->setDebugName(llvm::getTypeName<T>());
282  return pattern;
283  }
284 
285 protected:
286  /// Inherit the base constructors from `Pattern`.
287  using Pattern::Pattern;
288 
289 private:
290  /// Trait to check if T provides an `initialize` method.
291  template <typename T, typename... Args>
292  using has_initialize = decltype(std::declval<T>().initialize());
293  template <typename T>
294  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
295 
296  /// Initialize the derived pattern by calling its `initialize` method.
297  template <typename T>
299  initializePattern(T &pattern) {
300  pattern.initialize();
301  }
302  /// Empty derived pattern initializer for patterns that do not have an
303  /// initialize method.
304  template <typename T>
306  initializePattern(T &) {}
307 
308  /// An anchor for the virtual table.
309  virtual void anchor();
310 };
311 
312 namespace detail {
313 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
314 /// allows for matching and rewriting against an instance of a derived operation
315 /// class or Interface.
316 template <typename SourceOp>
318  using RewritePattern::RewritePattern;
319 
320  /// Wrappers around the RewritePattern methods that pass the derived op type.
321  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
322  rewrite(cast<SourceOp>(op), rewriter);
323  }
324  LogicalResult match(Operation *op) const final {
325  return match(cast<SourceOp>(op));
326  }
328  PatternRewriter &rewriter) const final {
329  return matchAndRewrite(cast<SourceOp>(op), rewriter);
330  }
331 
332  /// Rewrite and Match methods that operate on the SourceOp type. These must be
333  /// overridden by the derived pattern class.
334  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
335  llvm_unreachable("must override rewrite or matchAndRewrite");
336  }
337  virtual LogicalResult match(SourceOp op) const {
338  llvm_unreachable("must override match or matchAndRewrite");
339  }
340  virtual LogicalResult matchAndRewrite(SourceOp op,
341  PatternRewriter &rewriter) const {
342  if (succeeded(match(op))) {
343  rewrite(op, rewriter);
344  return success();
345  }
346  return failure();
347  }
348 };
349 } // namespace detail
350 
351 /// OpRewritePattern is a wrapper around RewritePattern that allows for
352 /// matching and rewriting against an instance of a derived operation class as
353 /// opposed to a raw Operation.
354 template <typename SourceOp>
356  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
357  /// Patterns must specify the root operation name they match against, and can
358  /// also specify the benefit of the pattern matching and a list of generated
359  /// ops.
361  ArrayRef<StringRef> generatedNames = {})
363  SourceOp::getOperationName(), benefit, context, generatedNames) {}
364 };
365 
366 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
367 /// matching and rewriting against an instance of an operation interface instead
368 /// of a raw Operation.
369 template <typename SourceOp>
371  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
373  : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
374  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
375  benefit, context) {}
376 };
377 
378 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
379 /// matching and rewriting against instances of an operation that possess a
380 /// given trait.
381 template <template <typename> class TraitType>
383 public:
385  : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
386  benefit, context) {}
387 };
388 
389 //===----------------------------------------------------------------------===//
390 // PDLPatternModule
391 //===----------------------------------------------------------------------===//
392 
393 //===----------------------------------------------------------------------===//
394 // PDLValue
395 
396 /// Storage type of byte-code interpreter values. These are passed to constraint
397 /// functions as arguments.
398 class PDLValue {
399 public:
400  /// The underlying kind of a PDL value.
402 
403  /// Construct a new PDL value.
404  PDLValue(const PDLValue &other) = default;
405  PDLValue(std::nullptr_t = nullptr) {}
407  : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
408  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
409  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
410  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
412  : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
413  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
414 
415  /// Returns true if the type of the held value is `T`.
416  template <typename T>
417  bool isa() const {
418  assert(value && "isa<> used on a null value");
419  return kind == getKindOf<T>();
420  }
421 
422  /// Attempt to dynamically cast this value to type `T`, returns null if this
423  /// value is not an instance of `T`.
424  template <typename T,
425  typename ResultT = std::conditional_t<
427  ResultT dyn_cast() const {
428  return isa<T>() ? castImpl<T>() : ResultT();
429  }
430 
431  /// Cast this value to type `T`, asserts if this value is not an instance of
432  /// `T`.
433  template <typename T>
434  T cast() const {
435  assert(isa<T>() && "expected value to be of type `T`");
436  return castImpl<T>();
437  }
438 
439  /// Get an opaque pointer to the value.
440  const void *getAsOpaquePointer() const { return value; }
441 
442  /// Return if this value is null or not.
443  explicit operator bool() const { return value; }
444 
445  /// Return the kind of this value.
446  Kind getKind() const { return kind; }
447 
448  /// Print this value to the provided output stream.
449  void print(raw_ostream &os) const;
450 
451  /// Print the specified value kind to an output stream.
452  static void print(raw_ostream &os, Kind kind);
453 
454 private:
455  /// Find the index of a given type in a range of other types.
456  template <typename...>
457  struct index_of_t;
458  template <typename T, typename... R>
459  struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
460  template <typename T, typename F, typename... R>
461  struct index_of_t<T, F, R...>
462  : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
463 
464  /// Return the kind used for the given T.
465  template <typename T>
466  static Kind getKindOf() {
467  return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
468  TypeRange, Value, ValueRange>::value);
469  }
470 
471  /// The internal implementation of `cast`, that returns the underlying value
472  /// as the given type `T`.
473  template <typename T>
475  castImpl() const {
476  return T::getFromOpaquePointer(value);
477  }
478  template <typename T>
480  castImpl() const {
481  return *reinterpret_cast<T *>(const_cast<void *>(value));
482  }
483  template <typename T>
484  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
485  return reinterpret_cast<T>(const_cast<void *>(value));
486  }
487 
488  /// The internal opaque representation of a PDLValue.
489  const void *value{nullptr};
490  /// The kind of the opaque value.
491  Kind kind{Kind::Attribute};
492 };
493 
494 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
495  value.print(os);
496  return os;
497 }
498 
499 inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
500  PDLValue::print(os, kind);
501  return os;
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // PDLResultList
506 
507 /// The class represents a list of PDL results, returned by a native rewrite
508 /// method. It provides the mechanism with which to pass PDLValues back to the
509 /// PDL bytecode.
511 public:
512  /// Push a new Attribute value onto the result list.
513  void push_back(Attribute value) { results.push_back(value); }
514 
515  /// Push a new Operation onto the result list.
516  void push_back(Operation *value) { results.push_back(value); }
517 
518  /// Push a new Type onto the result list.
519  void push_back(Type value) { results.push_back(value); }
520 
521  /// Push a new TypeRange onto the result list.
523  // The lifetime of a TypeRange can't be guaranteed, so we'll need to
524  // allocate a storage for it.
525  llvm::OwningArrayRef<Type> storage(value.size());
526  llvm::copy(value, storage.begin());
527  allocatedTypeRanges.emplace_back(std::move(storage));
528  typeRanges.push_back(allocatedTypeRanges.back());
529  results.push_back(&typeRanges.back());
530  }
532  typeRanges.push_back(value);
533  results.push_back(&typeRanges.back());
534  }
536  typeRanges.push_back(value);
537  results.push_back(&typeRanges.back());
538  }
539 
540  /// Push a new Value onto the result list.
541  void push_back(Value value) { results.push_back(value); }
542 
543  /// Push a new ValueRange onto the result list.
545  // The lifetime of a ValueRange can't be guaranteed, so we'll need to
546  // allocate a storage for it.
547  llvm::OwningArrayRef<Value> storage(value.size());
548  llvm::copy(value, storage.begin());
549  allocatedValueRanges.emplace_back(std::move(storage));
550  valueRanges.push_back(allocatedValueRanges.back());
551  results.push_back(&valueRanges.back());
552  }
554  valueRanges.push_back(value);
555  results.push_back(&valueRanges.back());
556  }
558  valueRanges.push_back(value);
559  results.push_back(&valueRanges.back());
560  }
561 
562 protected:
563  /// Create a new result list with the expected number of results.
564  PDLResultList(unsigned maxNumResults) {
565  // For now just reserve enough space for all of the results. We could do
566  // separate counts per range type, but it isn't really worth it unless there
567  // are a "large" number of results.
568  typeRanges.reserve(maxNumResults);
569  valueRanges.reserve(maxNumResults);
570  }
571 
572  /// The PDL results held by this list.
574  /// Memory used to store ranges held by the list.
577  /// Memory allocated to store ranges in the result list whose lifetime was
578  /// generated in the native function.
581 };
582 
583 //===----------------------------------------------------------------------===//
584 // PDLPatternModule
585 
586 /// A generic PDL pattern constraint function. This function applies a
587 /// constraint to a given set of opaque PDLValue entities. The second parameter
588 /// is a set of constant value parameters specified in Attribute form. Returns
589 /// success if the constraint successfully held, failure otherwise.
590 using PDLConstraintFunction = std::function<LogicalResult(
592 /// A native PDL rewrite function. This function performs a rewrite on the
593 /// given set of values and constant parameters. Any results from this rewrite
594 /// that should be passed back to PDL should be added to the provided result
595 /// list. This method is only invoked when the corresponding match was
596 /// successful.
597 using PDLRewriteFunction = std::function<void(
599 /// A generic PDL pattern constraint function. This function applies a
600 /// constraint to a given opaque PDLValue entity. The second parameter is a set
601 /// of constant value parameters specified in Attribute form. Returns success if
602 /// the constraint successfully held, failure otherwise.
604  std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>;
605 
606 /// This class contains all of the necessary data for a set of PDL patterns, or
607 /// pattern rewrites specified in the form of the PDL dialect. The PDL module
608 /// contained by this class may contain any number of `pdl.pattern`
609 /// operations.
611 public:
612  PDLPatternModule() = default;
613 
614  /// Construct a PDL pattern with the given module.
616  : pdlModule(std::move(pdlModule)) {}
617 
618  /// Merge the state in `other` into this pattern module.
619  void mergeIn(PDLPatternModule &&other);
620 
621  /// Return the internal PDL module of this pattern.
622  ModuleOp getModule() { return pdlModule.get(); }
623 
624  //===--------------------------------------------------------------------===//
625  // Function Registry
626 
627  /// Register a constraint function.
628  void registerConstraintFunction(StringRef name,
629  PDLConstraintFunction constraintFn);
630  /// Register a single entity constraint function.
631  template <typename SingleEntityFn>
632  std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
633  ArrayAttr, PatternRewriter &>::value>
634  registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
635  registerConstraintFunction(
636  name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
637  ArrayRef<PDLValue> values, ArrayAttr constantParams,
638  PatternRewriter &rewriter) {
639  assert(values.size() == 1 &&
640  "expected values to have a single entity");
641  return constraintFn(values[0], constantParams, rewriter);
642  });
643  }
644 
645  /// Register a rewrite function.
646  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
647 
648  /// Return the set of the registered constraint functions.
649  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
650  return constraintFunctions;
651  }
652  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
653  return constraintFunctions; // NOTE(review): returns a copy; the member map is not moved from or cleared.
654  }
655  /// Return the set of the registered rewrite functions.
656  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
657  return rewriteFunctions;
658  }
659  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
660  return rewriteFunctions; // NOTE(review): returns a copy; the member map is not moved from or cleared.
661  }
662 
663  /// Clear out the patterns and functions within this module.
664  void clear() {
665  pdlModule = nullptr;
666  constraintFunctions.clear();
667  rewriteFunctions.clear();
668  }
669 
670 private:
671  /// The module containing the `pdl.pattern` operations.
672  OwningModuleRef pdlModule;
673 
674  /// The external functions referenced from within the PDL module.
675  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
676  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
677 };
678 
679 //===----------------------------------------------------------------------===//
680 // RewriterBase
681 //===----------------------------------------------------------------------===//
682 
683 /// This class coordinates the application of a rewrite on a set of IR,
684 /// providing a way for clients to track mutations and create new operations.
685 /// This class serves as a common API for IR mutation between pattern rewrites
686 /// and non-pattern rewrites, and facilitates the development of shared
687 /// IR transformation utilities.
689 public:
690  /// Move the blocks that belong to "region" before the given position in
691  /// another region "parent". The two regions must be different. The caller
692  /// is responsible for creating or updating the operation transferring flow
693  /// of control to the region and passing it the correct block arguments.
694  virtual void inlineRegionBefore(Region &region, Region &parent,
695  Region::iterator before);
696  void inlineRegionBefore(Region &region, Block *before);
697 
698  /// Clone the blocks that belong to "region" before the given position in
699  /// another region "parent". The two regions must be different. The caller is
700  /// responsible for creating or updating the operation transferring flow of
701  /// control to the region and passing it the correct block arguments.
702  virtual void cloneRegionBefore(Region &region, Region &parent,
703  Region::iterator before,
704  BlockAndValueMapping &mapping);
705  void cloneRegionBefore(Region &region, Region &parent,
706  Region::iterator before);
707  void cloneRegionBefore(Region &region, Block *before);
708 
709  /// This method replaces the uses of the results of `op` with the values in
710  /// `newValues` when the provided `functor` returns true for a specific use.
711  /// The number of values in `newValues` is required to match the number of
712  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
713  /// the uses of `op` were replaced. Note that in some rewriters, the given
714  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
715  /// As such, the function should not capture by reference and instead use
716  /// value capture as necessary.
717  virtual void
718  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
719  llvm::unique_function<bool(OpOperand &) const> functor);
720  void replaceOpWithIf(Operation *op, ValueRange newValues,
721  llvm::unique_function<bool(OpOperand &) const> functor) {
722  replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
723  std::move(functor));
724  }
725 
726  /// This method replaces the uses of the results of `op` with the values in
727  /// `newValues` when a use is nested within the given `block`. The number of
728  /// values in `newValues` is required to match the number of results of `op`.
729  /// If all uses of this operation are replaced, the operation is erased.
730  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
731  bool *allUsesReplaced = nullptr);
732 
733  /// This method replaces the results of the operation with the specified list
734  /// of values. The number of provided values must match the number of results
735  /// of the operation.
736  virtual void replaceOp(Operation *op, ValueRange newValues);
737 
738  /// Replaces the result op with a new op that is created without verification.
739  /// The result values of the two ops must be the same types.
740  template <typename OpTy, typename... Args>
741  OpTy replaceOpWithNewOp(Operation *op, Args &&... args) {
742  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
743  replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
744  return newOp;
745  }
746 
747  /// This method erases an operation that is known to have no uses.
748  virtual void eraseOp(Operation *op);
749 
750  /// This method erases all operations in a block.
751  virtual void eraseBlock(Block *block);
752 
753  /// Merge the operations of block 'source' into the end of block 'dest'.
754  /// The predecessors of 'source' must either be empty or only contain 'dest'.
755  /// 'argValues' is used to replace the block arguments of 'source' after
756  /// merging.
757  virtual void mergeBlocks(Block *source, Block *dest,
758  ValueRange argValues = llvm::None);
759 
760  /// Merge the operations of block 'source' before the operation 'op'. Source
761  /// block should not have existing predecessors or successors.
762  void mergeBlockBefore(Block *source, Operation *op,
763  ValueRange argValues = llvm::None);
764 
765  /// Split the operations starting at "before" (inclusive) out of the given
766  /// block into a new block, and return it.
767  virtual Block *splitBlock(Block *block, Block::iterator before);
768 
769  /// This method is used to notify the rewriter that an in-place operation
770  /// modification is about to happen. A call to this function *must* be
771  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
772  /// This is a minor efficiency win (it avoids creating a new operation and
773  /// removing the old one) but also often allows simpler code in the client.
774  virtual void startRootUpdate(Operation *op) {}
775 
776  /// This method is used to signal the end of a root update on the given
777  /// operation. This can only be called on operations that were provided to a
778  /// call to `startRootUpdate`.
779  virtual void finalizeRootUpdate(Operation *op) {}
780 
781  /// This method cancels a pending root update. This can only be called on
782  /// operations that were provided to a call to `startRootUpdate`.
783  virtual void cancelRootUpdate(Operation *op) {}
784 
785  /// This method is a utility wrapper around a root update of an operation. It
786  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
787  /// callable.
788  template <typename CallableT>
789  void updateRootInPlace(Operation *root, CallableT &&callable) {
790  startRootUpdate(root);
791  callable();
792  finalizeRootUpdate(root);
793  }
794 
795  /// Used to notify the rewriter that the IR failed to be rewritten because of
796  /// a match failure, and provide a callback to populate a diagnostic with the
797  /// reason why the failure occurred. This method allows for derived rewriters
798  /// to optionally hook into the reason why a rewrite failed, and display it to
799  /// users.
800  template <typename CallbackT>
802  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
803 #ifndef NDEBUG
804  return notifyMatchFailure(op,
805  function_ref<void(Diagnostic &)>(reasonCallback));
806 #else
807  return failure();
808 #endif
809  }
810  LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
811  return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
812  }
814  return notifyMatchFailure(op, Twine(msg));
815  }
816 
817 protected:
818  /// Initialize the builder with this rewriter as the listener.
819  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
820  explicit RewriterBase(const OpBuilder &otherBuilder)
821  : OpBuilder(otherBuilder) {
822  setListener(this);
823  }
824  ~RewriterBase() override;
825 
826  /// These are the callback methods that subclasses can choose to implement if
827  /// they would like to be notified about certain types of mutations.
828 
829  /// Notify the rewriter that the specified operation is about to be replaced
830  /// with another set of operations. This is called before the uses of the
831  /// operation have been changed.
832  virtual void notifyRootReplaced(Operation *op) {}
833 
834  /// This is called on an operation that a rewrite is removing, right before
835  /// the operation is deleted. At this point, the operation has zero uses.
836  virtual void notifyOperationRemoved(Operation *op) {}
837 
838  /// Notify the rewriter that the pattern failed to match the given operation,
839  /// and provide a callback to populate a diagnostic with the reason why the
840  /// failure occurred. This method allows for derived rewriters to optionally
841  /// hook into the reason why a rewrite failed, and display it to users.
842  virtual LogicalResult
844  function_ref<void(Diagnostic &)> reasonCallback) {
845  return failure();
846  }
847 
848 private:
849  void operator=(const RewriterBase &) = delete;
850  RewriterBase(const RewriterBase &) = delete;
851 
852  /// 'op' and 'newOp' are known to have the same number of results, replace the
853  /// uses of op with uses of newOp.
854  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
855 };
856 
857 //===----------------------------------------------------------------------===//
858 // IRRewriter
859 //===----------------------------------------------------------------------===//
860 
861 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
862 /// providing a way to keep track of the mutations made to the IR. This class
863 /// should only be used in situations where another `RewriterBase` instance,
864 /// such as a `PatternRewriter`, is not available.
865 class IRRewriter : public RewriterBase {
866 public:
867  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
868  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
869 };
870 
871 //===----------------------------------------------------------------------===//
872 // PatternRewriter
873 //===----------------------------------------------------------------------===//
874 
875 /// A special type of `RewriterBase` that coordinates the application of a
876 /// rewrite pattern on the current IR being matched, providing a way to keep
877 /// track of any mutations made. This class should be used to perform all
878 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
879 /// be tracking various state that would be invalidated when a mutation takes
880 /// place.
882 public:
884 };
885 
886 //===----------------------------------------------------------------------===//
887 // RewritePatternSet
888 //===----------------------------------------------------------------------===//
889 
891  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
892 
893 public:
894  RewritePatternSet(MLIRContext *context) : context(context) {}
895 
896  /// Construct a RewritePatternSet populated with the given pattern.
898  std::unique_ptr<RewritePattern> pattern)
899  : context(context) {
900  nativePatterns.emplace_back(std::move(pattern));
901  }
903  : context(pattern.getModule()->getContext()),
904  pdlPatterns(std::move(pattern)) {}
905 
906  MLIRContext *getContext() const { return context; }
907 
908  /// Return the native patterns held in this list.
909  NativePatternListT &getNativePatterns() { return nativePatterns; }
910 
911  /// Return the PDL patterns held in this list.
912  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
913 
914  /// Clear out all of the held patterns in this list.
915  void clear() {
916  nativePatterns.clear();
917  pdlPatterns.clear();
918  }
919 
920  //===--------------------------------------------------------------------===//
921  // 'add' methods for adding patterns to the set.
922  //===--------------------------------------------------------------------===//
923 
924  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
925  /// the given arguments. Return a reference to `this` for chaining insertions.
926  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
927  template <typename... Ts, typename ConstructorArg,
928  typename... ConstructorArgs,
929  typename = std::enable_if_t<sizeof...(Ts) != 0>>
930  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&... args) {
931  // The following expands a call to emplace_back for each of the pattern
932  // types 'Ts'. This magic is necessary due to a limitation in the places
933  // that a parameter pack can be expanded in c++11.
934  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
935  (void)std::initializer_list<int>{
936  0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
937  return *this;
938  }
939  /// An overload of the above `add` method that allows for attaching a set
940  /// of debug labels to the attached patterns. This is useful for labeling
941  /// groups of patterns that may be shared between multiple different
942  /// passes/users.
943  template <typename... Ts, typename ConstructorArg,
944  typename... ConstructorArgs,
945  typename = std::enable_if_t<sizeof...(Ts) != 0>>
947  ConstructorArg &&arg,
948  ConstructorArgs &&... args) {
949  // The following expands a call to emplace_back for each of the pattern
950  // types 'Ts'. This magic is necessary due to a limitation in the places
951  // that a parameter pack can be expanded in c++11.
952  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
953  (void)std::initializer_list<int>{
954  0, (addImpl<Ts>(debugLabels, arg, args...), 0)...};
955  return *this;
956  }
957 
958  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
959  /// `this` for chaining insertions.
960  template <typename... Ts>
962  (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
963  return *this;
964  }
965 
966  /// Add the given native pattern to the pattern list. Return a reference to
967  /// `this` for chaining insertions.
968  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
969  nativePatterns.emplace_back(std::move(pattern));
970  return *this;
971  }
972 
973  /// Add the given PDL pattern to the pattern list. Return a reference to
974  /// `this` for chaining insertions.
976  pdlPatterns.mergeIn(std::move(pattern));
977  return *this;
978  }
979 
980  // Add a matchAndRewrite style pattern represented as a C function pointer.
981  template <typename OpType>
983  PatternRewriter &rewriter)) {
984  struct FnPattern final : public OpRewritePattern<OpType> {
985  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
986  MLIRContext *context)
987  : OpRewritePattern<OpType>(context), implFn(implFn) {}
988 
989  LogicalResult matchAndRewrite(OpType op,
990  PatternRewriter &rewriter) const override {
991  return implFn(op, rewriter);
992  }
993 
994  private:
995  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
996  };
997  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
998  return *this;
999  }
1000 
1001  //===--------------------------------------------------------------------===//
1002  // Pattern Insertion
1003  //===--------------------------------------------------------------------===//
1004 
1005  // TODO: These are soft deprecated in favor of the 'add' methods above.
1006 
1007  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1008  /// the given arguments. Return a reference to `this` for chaining insertions.
1009  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1010  template <typename... Ts, typename ConstructorArg,
1011  typename... ConstructorArgs,
1012  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1013  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
1014  // The following expands a call to emplace_back for each of the pattern
1015  // types 'Ts'. This magic is necessary due to a limitation in the places
1016  // that a parameter pack can be expanded in c++11.
1017  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
1018  (void)std::initializer_list<int>{
1019  0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
1020  return *this;
1021  }
1022 
1023  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1024  /// `this` for chaining insertions.
1025  template <typename... Ts>
1027  (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
1028  return *this;
1029  }
1030 
1031  /// Add the given native pattern to the pattern list. Return a reference to
1032  /// `this` for chaining insertions.
1033  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
1034  nativePatterns.emplace_back(std::move(pattern));
1035  return *this;
1036  }
1037 
1038  /// Add the given PDL pattern to the pattern list. Return a reference to
1039  /// `this` for chaining insertions.
1041  pdlPatterns.mergeIn(std::move(pattern));
1042  return *this;
1043  }
1044 
1045  // Add a matchAndRewrite style pattern represented as a C function pointer.
1046  template <typename OpType>
1048  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
1049  struct FnPattern final : public OpRewritePattern<OpType> {
1050  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1051  MLIRContext *context)
1052  : OpRewritePattern<OpType>(context), implFn(implFn) {
1053  this->setDebugName(llvm::getTypeName<FnPattern>());
1054  }
1055 
1056  LogicalResult matchAndRewrite(OpType op,
1057  PatternRewriter &rewriter) const override {
1058  return implFn(op, rewriter);
1059  }
1060 
1061  private:
1062  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1063  };
1064  insert(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1065  return *this;
1066  }
1067 
1068 private:
1069  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
1070  /// chaining insertions.
1071  template <typename T, typename... Args>
1073  addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
1074  std::unique_ptr<T> pattern =
1075  RewritePattern::create<T>(std::forward<Args>(args)...);
1076  pattern->addDebugLabels(debugLabels);
1077  nativePatterns.emplace_back(std::move(pattern));
1078  }
1079  template <typename T, typename... Args>
1081  addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
1082  // TODO: Add the provided labels to the PDL pattern when PDL supports
1083  // labels.
1084  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1085  }
1086 
1087  MLIRContext *const context;
1088  NativePatternListT nativePatterns;
1089  PDLPatternModule pdlPatterns;
1090 };
1091 
1092 } // namespace mlir
1093 
1094 #endif // MLIR_IR_PATTERNMATCH_H
T cast() const
Cast this value to type T, asserts if this value is not an instance of T.
Definition: PatternMatch.h:434
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:52
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
PDLValue(TypeRange *value)
Definition: PatternMatch.h:410
static std::string diag(llvm::Value &v)
void push_back(ResultRange value)
Definition: PatternMatch.h:557
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:101
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:162
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: PatternMatch.h:334
bool isa() const
Returns true if the type of the held value is T.
Definition: PatternMatch.h:417
ResultT dyn_cast() const
Attempt to dynamically cast this value to type T, returns null if this value is not an instance of T...
Definition: PatternMatch.h:427
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
void push_back(OperandRange value)
Definition: PatternMatch.h:553
const llvm::StringMap< PDLRewriteFunction > & getRewriteFunctions() const
Return the set of the registered rewrite functions.
Definition: PatternMatch.h:656
SmallVector< PDLValue > results
The PDL results held by this list.
Definition: PatternMatch.h:573
Block represents an ordered list of Operations.
Definition: Block.h:29
PDLResultList(unsigned maxNumResults)
Create a new result list with the expected number of results.
Definition: PatternMatch.h:564
PDLValue(Type value)
Definition: PatternMatch.h:409
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:868
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:975
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
BlockListType::iterator iterator
Definition: Region.h:52
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:372
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:384
void print(OpAsmPrinter &p, FunctionLibraryOp op)
Definition: Shape.cpp:1112
std::function< LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)> PDLSingleEntityConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:604
This class implements the result iterators for the Operation class.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
void push_back(Operation *value)
Push a new Operation onto the result list.
Definition: PatternMatch.h:516
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:324
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
Definition: PatternMatch.h:659
PatternBenefit & operator=(const PatternBenefit &)=default
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:774
const void * getAsOpaquePointer() const
Get an opaque pointer to the value.
Definition: PatternMatch.h:440
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: PatternMatch.h:321
PDLValue(Value value)
Definition: PatternMatch.h:411
This class represents a listener that may be used to hook into various actions within an OpBuilder...
Definition: Builders.h:234
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
Definition: PatternMatch.h:897
PDLValue(Attribute value)
Definition: PatternMatch.h:406
static constexpr const bool value
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:52
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:151
static OperationName getFromOpaquePointer(const void *pointer)
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
Definition: PatternMatch.h:909
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
Definition: PatternMatch.h:652
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
This class acts as an owning reference to a module, and will automatically destroy the held module on...
Definition: BuiltinOps.h:42
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:200
void push_back(TypeRange value)
Push a new TypeRange onto the result list.
Definition: PatternMatch.h:522
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
ModuleOp getModule()
Return the internal PDL module of this pattern.
Definition: PatternMatch.h:622
static std::unique_ptr< T > create(Args &&... args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:274
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
PDLValue(ValueRange *value)
Definition: PatternMatch.h:413
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:41
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:56
OpListType::iterator iterator
Definition: Block.h:131
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
Definition: PatternMatch.h:610
LogicalResult notifyMatchFailure(Operation *op, const Twine &msg)
Definition: PatternMatch.h:810
std::enable_if_t<!llvm::is_invocable< SingleEntityFn, ArrayRef< PDLValue >, ArrayAttr, PatternRewriter & >::value > registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn)
Register a single entity constraint function.
Definition: PatternMatch.h:634
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:398
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:327
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
void push_back(Attribute value)
Push a new Attribute value onto the result list.
Definition: PatternMatch.h:513
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name...
void clear()
Clear out the patterns and functions within this module.
Definition: PatternMatch.h:664
Optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:110
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:145
void push_back(ValueRange value)
Push a new ValueRange onto the result list.
Definition: PatternMatch.h:544
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:88
std::function< void(ArrayRef< PDLValue >, ArrayAttr, PatternRewriter &, PDLResultList &)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:598
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:142
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void push_back(Type value)
Push a new Type onto the result list.
Definition: PatternMatch.h:519
virtual void notifyRootReplaced(Operation *op)
These are the callback methods that subclasses can choose to implement if they would like to be notif...
Definition: PatternMatch.h:832
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
virtual LogicalResult notifyMatchFailure(Operation *op, function_ref< void(Diagnostic &)> reasonCallback)
Notify the rewriter that the pattern failed to match the given operation, and provide a callback to p...
Definition: PatternMatch.h:843
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:195
This class acts as a special tag that makes the desire to match "any" operation type explicit...
Definition: PatternMatch.h:157
PatternBenefit()=default
RewriterBase(MLIRContext *ctx)
Initialize the builder with this rewriter as the listener.
Definition: PatternMatch.h:819
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:262
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:127
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:961
const llvm::StringMap< PDLConstraintFunction > & getConstraintFunctions() const
Return the set of the registered constraint functions.
Definition: PatternMatch.h:649
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
Definition: PatternMatch.h:779
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:401
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:820
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:865
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:55
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:968
SmallVector< TypeRange > typeRanges
Memory used to store ranges held by the list.
Definition: PatternMatch.h:575
SmallVector< ValueRange > valueRanges
Definition: PatternMatch.h:576
RewritePatternSet(MLIRContext *context)
Definition: PatternMatch.h:894
This class implements iteration on the types of a given range of values.
Definition: Block.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
PDLPatternModule(OwningModuleRef pdlModule)
Construct a PDL pattern with the given module.
Definition: PatternMatch.h:615
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:80
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:148
Kind getKind() const
Return the kind of this value.
Definition: PatternMatch.h:446
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:340
virtual LogicalResult match(SourceOp op) const
Definition: PatternMatch.h:337
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
Definition: PatternMatch.h:982
IRRewriter(MLIRContext *ctx)
Definition: PatternMatch.h:867
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:48
This class represents an operand of an operation.
Definition: Value.h:249
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:915
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:317
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:510
This class implements the operand iterators for the Operation class.
void replaceOpWithIf(Operation *op, ValueRange newValues, llvm::unique_function< bool(OpOperand &) const > functor)
Definition: PatternMatch.h:720
PDLValue(Operation *value)
Definition: PatternMatch.h:408
std::function< LogicalResult(ArrayRef< PDLValue >, ArrayAttr, PatternRewriter &)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:591
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:167
void push_back(ValueTypeRange< OperandRange > value)
Definition: PatternMatch.h:531
void push_back(ValueTypeRange< ResultRange > value)
Definition: PatternMatch.h:535
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:912
LogicalResult notifyMatchFailure(Operation *op, const char *msg)
Definition: PatternMatch.h:813
virtual void notifyOperationRemoved(Operation *op)
This is called on an operation that a rewrite is removing, right before the operation is deleted...
Definition: PatternMatch.h:836
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:802
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132
RewritePatternSet(PDLPatternModule &&pattern)
Definition: PatternMatch.h:902
void push_back(Value value)
Push a new Value onto the result list.
Definition: PatternMatch.h:541
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:138
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&... args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
Definition: PatternMatch.h:946
PDLValue(std::nullptr_t=nullptr)
Definition: PatternMatch.h:405
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
SmallVector< llvm::OwningArrayRef< Type > > allocatedTypeRanges
Memory allocated to store ranges in the result list whose lifetime was generated in the native functi...
Definition: PatternMatch.h:579
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:51
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
Definition: PatternMatch.h:783
void print(raw_ostream &os) const
Print this value to the provided output stream.
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
SmallVector< llvm::OwningArrayRef< Value > > allocatedValueRanges
Definition: PatternMatch.h:580
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:121
MLIRContext * getContext() const
Definition: PatternMatch.h:906
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:688
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:382
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:370
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:92