MLIR  16.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 
17 namespace mlir {
18 
19 class PatternRewriter;
20 
21 //===----------------------------------------------------------------------===//
22 // PatternBenefit class
23 //===----------------------------------------------------------------------===//
24 
25 /// This class represents the benefit of a pattern match in a unitless scheme
26 /// that ranges from 0 (very little benefit) to 65K. The most common unit to
27 /// use here is the "number of operations matched" by the pattern.
28 ///
29 /// This also has a sentinel representation that can be used for patterns that
30 /// fail to match.
31 ///
33  enum { ImpossibleToMatchSentinel = 65535 };
34 
35 public:
36  PatternBenefit() = default;
37  PatternBenefit(unsigned benefit);
38  PatternBenefit(const PatternBenefit &) = default;
40 
42  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
43 
44  /// If the corresponding pattern can match, return its benefit. If the
45  // corresponding pattern isImpossibleToMatch() then this aborts.
46  unsigned short getBenefit() const;
47 
48  bool operator==(const PatternBenefit &rhs) const {
49  return representation == rhs.representation;
50  }
51  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
52  bool operator<(const PatternBenefit &rhs) const {
53  return representation < rhs.representation;
54  }
55  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
56  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
57  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
58 
59 private:
60  unsigned short representation{ImpossibleToMatchSentinel};
61 };
62 
63 //===----------------------------------------------------------------------===//
64 // Pattern
65 //===----------------------------------------------------------------------===//
66 
67 /// This class contains all of the data related to a pattern, but does not
68 /// contain any methods or logic for the actual matching. This class is solely
69 /// used to interface with the metadata of a pattern, such as the benefit or
70 /// root operation.
71 class Pattern {
72  /// This enum represents the kind of value used to select the root operations
73  /// that match this pattern.
74  enum class RootKind {
75  /// The pattern root matches "any" operation.
76  Any,
77  /// The pattern root is matched using a concrete operation name.
79  /// The pattern root is matched using an interface ID.
80  InterfaceID,
    /// The pattern root is matched using a trait ID.
82  TraitID
83  };
84 
85 public:
86  /// Return a list of operations that may be generated when rewriting an
87  /// operation instance with this pattern.
88  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
89 
90  /// Return the root node that this pattern matches. Patterns that can match
91  /// multiple root types return None.
93  if (rootKind == RootKind::OperationName)
94  return OperationName::getFromOpaquePointer(rootValue);
95  return llvm::None;
96  }
97 
98  /// Return the interface ID used to match the root operation of this pattern.
99  /// If the pattern does not use an interface ID for deciding the root match,
100  /// this returns None.
102  if (rootKind == RootKind::InterfaceID)
103  return TypeID::getFromOpaquePointer(rootValue);
104  return llvm::None;
105  }
106 
107  /// Return the trait ID used to match the root operation of this pattern.
108  /// If the pattern does not use a trait ID for deciding the root match, this
109  /// returns None.
111  if (rootKind == RootKind::TraitID)
112  return TypeID::getFromOpaquePointer(rootValue);
113  return llvm::None;
114  }
115 
116  /// Return the benefit (the inverse of "cost") of matching this pattern. The
117  /// benefit of a Pattern is always static - rewrites that may have dynamic
118  /// benefit can be instantiated multiple times (different Pattern instances)
119  /// for each benefit that they may return, and be guarded by different match
120  /// condition predicates.
121  PatternBenefit getBenefit() const { return benefit; }
122 
123  /// Returns true if this pattern is known to result in recursive application,
124  /// i.e. this pattern may generate IR that also matches this pattern, but is
125  /// known to bound the recursion. This signals to a rewrite driver that it is
126  /// safe to apply this pattern recursively to generated IR.
128  return contextAndHasBoundedRecursion.getInt();
129  }
130 
131  /// Return the MLIRContext used to create this pattern.
133  return contextAndHasBoundedRecursion.getPointer();
134  }
135 
136  /// Return a readable name for this pattern. This name should only be used for
137  /// debugging purposes, and may be empty.
138  StringRef getDebugName() const { return debugName; }
139 
140  /// Set the human readable debug name used for this pattern. This name will
141  /// only be used for debugging purposes.
142  void setDebugName(StringRef name) { debugName = name; }
143 
144  /// Return the set of debug labels attached to this pattern.
145  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
146 
147  /// Add the provided debug labels to this pattern.
149  debugLabels.append(labels.begin(), labels.end());
150  }
151  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
152 
153 protected:
154  /// This class acts as a special tag that makes the desire to match "any"
155  /// operation type explicit. This helps to avoid unnecessary usages of this
156  /// feature, and ensures that the user is making a conscious decision.
157  struct MatchAnyOpTypeTag {};
158  /// This class acts as a special tag that makes the desire to match any
159  /// operation that implements a given interface explicit. This helps to avoid
160  /// unnecessary usages of this feature, and ensures that the user is making a
161  /// conscious decision.
163  /// This class acts as a special tag that makes the desire to match any
164  /// operation that implements a given trait explicit. This helps to avoid
165  /// unnecessary usages of this feature, and ensures that the user is making a
166  /// conscious decision.
168 
169  /// Construct a pattern with a certain benefit that matches the operation
170  /// with the given root name.
171  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
172  ArrayRef<StringRef> generatedNames = {});
173  /// Construct a pattern that may match any operation type. `generatedNames`
174  /// contains the names of operations that may be generated during a successful
175  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
176  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
177  /// always be supplied here.
178  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
179  ArrayRef<StringRef> generatedNames = {});
180  /// Construct a pattern that may match any operation that implements the
181  /// interface defined by the provided `interfaceID`. `generatedNames` contains
182  /// the names of operations that may be generated during a successful rewrite.
183  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
184  /// interface" behavior is what the user actually desired,
185  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
186  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
187  PatternBenefit benefit, MLIRContext *context,
188  ArrayRef<StringRef> generatedNames = {});
189  /// Construct a pattern that may match any operation that implements the
190  /// trait defined by the provided `traitID`. `generatedNames` contains the
191  /// names of operations that may be generated during a successful rewrite.
192  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
193  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
194  /// always be supplied here.
195  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
196  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
197 
198  /// Set the flag detailing if this pattern has bounded rewrite recursion or
199  /// not.
200  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
201  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
202  }
203 
204 private:
205  Pattern(const void *rootValue, RootKind rootKind,
206  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
207  MLIRContext *context);
208 
209  /// The value used to match the root operation of the pattern.
210  const void *rootValue;
211  RootKind rootKind;
212 
213  /// The expected benefit of matching this pattern.
214  const PatternBenefit benefit;
215 
216  /// The context this pattern was created from, and a boolean flag indicating
217  /// whether this pattern has bounded recursion or not.
218  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
219 
220  /// A list of the potential operations that may be generated when rewriting
221  /// an op with this pattern.
222  SmallVector<OperationName, 2> generatedOps;
223 
224  /// A readable name for this pattern. May be empty.
225  StringRef debugName;
226 
227  /// The set of debug labels attached to this pattern.
228  SmallVector<StringRef, 0> debugLabels;
229 };
230 
231 //===----------------------------------------------------------------------===//
232 // RewritePattern
233 //===----------------------------------------------------------------------===//
234 
235 /// RewritePattern is the common base class for all DAG to DAG replacements.
236 /// There are two possible usages of this class:
237 /// * Multi-step RewritePattern with "match" and "rewrite"
238 /// - By overloading the "match" and "rewrite" functions, the user can
239 /// separate the concerns of matching and rewriting.
240 /// * Single-step RewritePattern with "matchAndRewrite"
241 /// - By overloading the "matchAndRewrite" function, the user can perform
242 /// the rewrite in the same call as the match.
243 ///
244 class RewritePattern : public Pattern {
245 public:
246  virtual ~RewritePattern() = default;
247 
248  /// Rewrite the IR rooted at the specified operation with the result of
249  /// this pattern, generating any new operations with the specified
250  /// builder. If an unexpected error is encountered (an internal
251  /// compiler error), it is emitted through the normal MLIR diagnostic
252  /// hooks and the IR is left in a valid state.
253  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
254 
255  /// Attempt to match against code rooted at the specified operation,
256  /// which is the same operation code as getRootKind().
257  virtual LogicalResult match(Operation *op) const;
258 
259  /// Attempt to match against code rooted at the specified operation,
260  /// which is the same operation code as getRootKind(). If successful, this
261  /// function will automatically perform the rewrite.
263  PatternRewriter &rewriter) const {
264  if (succeeded(match(op))) {
265  rewrite(op, rewriter);
266  return success();
267  }
268  return failure();
269  }
270 
271  /// This method provides a convenient interface for creating and initializing
272  /// derived rewrite patterns of the given type `T`.
273  template <typename T, typename... Args>
274  static std::unique_ptr<T> create(Args &&...args) {
275  std::unique_ptr<T> pattern =
276  std::make_unique<T>(std::forward<Args>(args)...);
277  initializePattern<T>(*pattern);
278 
279  // Set a default debug name if one wasn't provided.
280  if (pattern->getDebugName().empty())
281  pattern->setDebugName(llvm::getTypeName<T>());
282  return pattern;
283  }
284 
285 protected:
286  /// Inherit the base constructors from `Pattern`.
287  using Pattern::Pattern;
288 
289 private:
290  /// Trait to check if T provides a `getOperationName` method.
291  template <typename T, typename... Args>
292  using has_initialize = decltype(std::declval<T>().initialize());
293  template <typename T>
294  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
295 
296  /// Initialize the derived pattern by calling its `initialize` method.
297  template <typename T>
299  initializePattern(T &pattern) {
300  pattern.initialize();
301  }
302  /// Empty derived pattern initializer for patterns that do not have an
303  /// initialize method.
304  template <typename T>
306  initializePattern(T &) {}
307 
308  /// An anchor for the virtual table.
309  virtual void anchor();
310 };
311 
312 namespace detail {
313 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
314 /// allows for matching and rewriting against an instance of a derived operation
315 /// class or Interface.
316 template <typename SourceOp>
318  using RewritePattern::RewritePattern;
319 
320  /// Wrappers around the RewritePattern methods that pass the derived op type.
321  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
322  rewrite(cast<SourceOp>(op), rewriter);
323  }
324  LogicalResult match(Operation *op) const final {
325  return match(cast<SourceOp>(op));
326  }
328  PatternRewriter &rewriter) const final {
329  return matchAndRewrite(cast<SourceOp>(op), rewriter);
330  }
331 
332  /// Rewrite and Match methods that operate on the SourceOp type. These must be
333  /// overridden by the derived pattern class.
334  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
335  llvm_unreachable("must override rewrite or matchAndRewrite");
336  }
337  virtual LogicalResult match(SourceOp op) const {
338  llvm_unreachable("must override match or matchAndRewrite");
339  }
340  virtual LogicalResult matchAndRewrite(SourceOp op,
341  PatternRewriter &rewriter) const {
342  if (succeeded(match(op))) {
343  rewrite(op, rewriter);
344  return success();
345  }
346  return failure();
347  }
348 };
349 } // namespace detail
350 
351 /// OpRewritePattern is a wrapper around RewritePattern that allows for
352 /// matching and rewriting against an instance of a derived operation class as
353 /// opposed to a raw Operation.
354 template <typename SourceOp>
356  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
357  /// Patterns must specify the root operation name they match against, and can
358  /// also specify the benefit of the pattern matching and a list of generated
359  /// ops.
361  ArrayRef<StringRef> generatedNames = {})
363  SourceOp::getOperationName(), benefit, context, generatedNames) {}
364 };
365 
366 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
367 /// matching and rewriting against an instance of an operation interface instead
368 /// of a raw Operation.
369 template <typename SourceOp>
371  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
373  : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
374  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
375  benefit, context) {}
376 };
377 
378 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
379 /// matching and rewriting against instances of an operation that possess a
380 /// given trait.
381 template <template <typename> class TraitType>
383 public:
385  : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
386  benefit, context) {}
387 };
388 
389 //===----------------------------------------------------------------------===//
390 // RewriterBase
391 //===----------------------------------------------------------------------===//
392 
393 /// This class coordinates the application of a rewrite on a set of IR,
394 /// providing a way for clients to track mutations and create new operations.
395 /// This class serves as a common API for IR mutation between pattern rewrites
396 /// and non-pattern rewrites, and facilitates the development of shared
397 /// IR transformation utilities.
399 public:
400  /// Move the blocks that belong to "region" before the given position in
401  /// another region "parent". The two regions must be different. The caller
402  /// is responsible for creating or updating the operation transferring flow
403  /// of control to the region and passing it the correct block arguments.
404  virtual void inlineRegionBefore(Region &region, Region &parent,
405  Region::iterator before);
406  void inlineRegionBefore(Region &region, Block *before);
407 
408  /// Clone the blocks that belong to "region" before the given position in
409  /// another region "parent". The two regions must be different. The caller is
410  /// responsible for creating or updating the operation transferring flow of
411  /// control to the region and passing it the correct block arguments.
412  virtual void cloneRegionBefore(Region &region, Region &parent,
413  Region::iterator before,
414  BlockAndValueMapping &mapping);
415  void cloneRegionBefore(Region &region, Region &parent,
416  Region::iterator before);
417  void cloneRegionBefore(Region &region, Block *before);
418 
419  /// This method replaces the uses of the results of `op` with the values in
420  /// `newValues` when the provided `functor` returns true for a specific use.
421  /// The number of values in `newValues` is required to match the number of
422  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
423  /// the uses of `op` were replaced. Note that in some rewriters, the given
424  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
425  /// As such, the function should not capture by reference and instead use
426  /// value capture as necessary.
427  virtual void
428  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
429  llvm::unique_function<bool(OpOperand &) const> functor);
430  void replaceOpWithIf(Operation *op, ValueRange newValues,
431  llvm::unique_function<bool(OpOperand &) const> functor) {
432  replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
433  std::move(functor));
434  }
435 
436  /// This method replaces the uses of the results of `op` with the values in
437  /// `newValues` when a use is nested within the given `block`. The number of
438  /// values in `newValues` is required to match the number of results of `op`.
439  /// If all uses of this operation are replaced, the operation is erased.
440  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
441  bool *allUsesReplaced = nullptr);
442 
443  /// This method replaces the results of the operation with the specified list
444  /// of values. The number of provided values must match the number of results
445  /// of the operation.
446  virtual void replaceOp(Operation *op, ValueRange newValues);
447 
448  /// Replaces the result op with a new op that is created without verification.
449  /// The result values of the two ops must be the same types.
450  template <typename OpTy, typename... Args>
451  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
452  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
453  replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
454  return newOp;
455  }
456 
457  /// This method erases an operation that is known to have no uses.
458  virtual void eraseOp(Operation *op);
459 
460  /// This method erases all operations in a block.
461  virtual void eraseBlock(Block *block);
462 
463  /// Merge the operations of block 'source' into the end of block 'dest'.
464  /// 'source's predecessors must either be empty or only contain 'dest`.
465  /// 'argValues' is used to replace the block arguments of 'source' after
466  /// merging.
467  virtual void mergeBlocks(Block *source, Block *dest,
468  ValueRange argValues = llvm::None);
469 
470  // Merge the operations of block 'source' before the operation 'op'. Source
471  // block should not have existing predecessors or successors.
472  void mergeBlockBefore(Block *source, Operation *op,
473  ValueRange argValues = llvm::None);
474 
475  /// Split the operations starting at "before" (inclusive) out of the given
476  /// block into a new block, and return it.
477  virtual Block *splitBlock(Block *block, Block::iterator before);
478 
479  /// This method is used to notify the rewriter that an in-place operation
480  /// modification is about to happen. A call to this function *must* be
481  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
482  /// This is a minor efficiency win (it avoids creating a new operation and
483  /// removing the old one) but also often allows simpler code in the client.
484  virtual void startRootUpdate(Operation *op) {}
485 
486  /// This method is used to signal the end of a root update on the given
487  /// operation. This can only be called on operations that were provided to a
488  /// call to `startRootUpdate`.
489  virtual void finalizeRootUpdate(Operation *op) {}
490 
491  /// This method cancels a pending root update. This can only be called on
492  /// operations that were provided to a call to `startRootUpdate`.
493  virtual void cancelRootUpdate(Operation *op) {}
494 
495  /// This method is a utility wrapper around a root update of an operation. It
496  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
497  /// callable.
498  template <typename CallableT>
499  void updateRootInPlace(Operation *root, CallableT &&callable) {
500  startRootUpdate(root);
501  callable();
502  finalizeRootUpdate(root);
503  }
504 
505  /// Find uses of `from` and replace it with `to`. It also marks every modified
506  /// uses and notifies the rewriter that an in-place operation modification is
507  /// about to happen.
508  void replaceAllUsesWith(Value from, Value to);
509 
510  /// Used to notify the rewriter that the IR failed to be rewritten because of
511  /// a match failure, and provide a callback to populate a diagnostic with the
512  /// reason why the failure occurred. This method allows for derived rewriters
513  /// to optionally hook into the reason why a rewrite failed, and display it to
514  /// users.
515  template <typename CallbackT>
517  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
518 #ifndef NDEBUG
519  return notifyMatchFailure(loc,
520  function_ref<void(Diagnostic &)>(reasonCallback));
521 #else
522  return failure();
523 #endif
524  }
525  template <typename CallbackT>
527  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
528  return notifyMatchFailure(op->getLoc(),
529  function_ref<void(Diagnostic &)>(reasonCallback));
530  }
531  template <typename ArgT>
532  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
533  return notifyMatchFailure(std::forward<ArgT>(arg),
534  [&](Diagnostic &diag) { diag << msg; });
535  }
536  template <typename ArgT>
537  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
538  return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
539  }
540 
541 protected:
542  /// Initialize the builder with this rewriter as the listener.
543  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
544  explicit RewriterBase(const OpBuilder &otherBuilder)
545  : OpBuilder(otherBuilder) {
546  setListener(this);
547  }
548  ~RewriterBase() override;
549 
550  /// These are the callback methods that subclasses can choose to implement if
551  /// they would like to be notified about certain types of mutations.
552 
553  /// Notify the rewriter that the specified operation is about to be replaced
554  /// with the set of values potentially produced by new operations. This is
555  /// called before the uses of the operation have been changed.
556  virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {}
557 
558  /// This is called on an operation that a rewrite is removing, right before
559  /// the operation is deleted. At this point, the operation has zero uses.
560  virtual void notifyOperationRemoved(Operation *op) {}
561 
562  /// Notify the rewriter that the pattern failed to match the given operation,
563  /// and provide a callback to populate a diagnostic with the reason why the
564  /// failure occurred. This method allows for derived rewriters to optionally
565  /// hook into the reason why a rewrite failed, and display it to users.
566  virtual LogicalResult
568  function_ref<void(Diagnostic &)> reasonCallback) {
569  return failure();
570  }
571 
572 private:
573  void operator=(const RewriterBase &) = delete;
574  RewriterBase(const RewriterBase &) = delete;
575 
576  /// 'op' and 'newOp' are known to have the same number of results, replace the
577  /// uses of op with uses of newOp.
578  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
579 };
580 
581 //===----------------------------------------------------------------------===//
582 // IRRewriter
583 //===----------------------------------------------------------------------===//
584 
585 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
586 /// providing a way to keep track of the mutations made to the IR. This class
587 /// should only be used in situations where another `RewriterBase` instance,
588 /// such as a `PatternRewriter`, is not available.
589 class IRRewriter : public RewriterBase {
590 public:
591  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
592  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
593 };
594 
595 //===----------------------------------------------------------------------===//
596 // PatternRewriter
597 //===----------------------------------------------------------------------===//
598 
599 /// A special type of `RewriterBase` that coordinates the application of a
600 /// rewrite pattern on the current IR being matched, providing a way to keep
601 /// track of any mutations made. This class should be used to perform all
602 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
603 /// be tracking various state that would be invalidated when a mutation takes
604 /// place.
606 public:
608 
609  /// A hook used to indicate if the pattern rewriter can recover from failure
610  /// during the rewrite stage of a pattern. For example, if the pattern
611  /// rewriter supports rollback, it may progress smoothly even if IR was
612  /// changed during the rewrite.
613  virtual bool canRecoverFromRewriteFailure() const { return false; }
614 };
615 
616 //===----------------------------------------------------------------------===//
617 // PDL Patterns
618 //===----------------------------------------------------------------------===//
619 
620 //===----------------------------------------------------------------------===//
621 // PDLValue
622 
623 /// Storage type of byte-code interpreter values. These are passed to constraint
624 /// functions as arguments.
625 class PDLValue {
626 public:
627  /// The underlying kind of a PDL value.
629 
630  /// Construct a new PDL value.
631  PDLValue(const PDLValue &other) = default;
632  PDLValue(std::nullptr_t = nullptr) {}
634  : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
635  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
636  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
637  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
639  : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
640  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
641 
642  /// Returns true if the type of the held value is `T`.
643  template <typename T>
644  bool isa() const {
645  assert(value && "isa<> used on a null value");
646  return kind == getKindOf<T>();
647  }
648 
649  /// Attempt to dynamically cast this value to type `T`, returns null if this
650  /// value is not an instance of `T`.
651  template <typename T,
652  typename ResultT = std::conditional_t<
654  ResultT dyn_cast() const {
655  return isa<T>() ? castImpl<T>() : ResultT();
656  }
657 
658  /// Cast this value to type `T`, asserts if this value is not an instance of
659  /// `T`.
660  template <typename T>
661  T cast() const {
662  assert(isa<T>() && "expected value to be of type `T`");
663  return castImpl<T>();
664  }
665 
666  /// Get an opaque pointer to the value.
667  const void *getAsOpaquePointer() const { return value; }
668 
669  /// Return if this value is null or not.
670  explicit operator bool() const { return value; }
671 
672  /// Return the kind of this value.
673  Kind getKind() const { return kind; }
674 
675  /// Print this value to the provided output stream.
676  void print(raw_ostream &os) const;
677 
678  /// Print the specified value kind to an output stream.
679  static void print(raw_ostream &os, Kind kind);
680 
681 private:
682  /// Find the index of a given type in a range of other types.
683  template <typename...>
684  struct index_of_t;
685  template <typename T, typename... R>
686  struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
687  template <typename T, typename F, typename... R>
688  struct index_of_t<T, F, R...>
689  : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
690 
691  /// Return the kind used for the given T.
692  template <typename T>
693  static Kind getKindOf() {
694  return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
695  TypeRange, Value, ValueRange>::value);
696  }
697 
698  /// The internal implementation of `cast`, that returns the underlying value
699  /// as the given type `T`.
700  template <typename T>
702  castImpl() const {
703  return T::getFromOpaquePointer(value);
704  }
705  template <typename T>
707  castImpl() const {
708  return *reinterpret_cast<T *>(const_cast<void *>(value));
709  }
710  template <typename T>
711  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
712  return reinterpret_cast<T>(const_cast<void *>(value));
713  }
714 
715  /// The internal opaque representation of a PDLValue.
716  const void *value{nullptr};
717  /// The kind of the opaque value.
718  Kind kind{Kind::Attribute};
719 };
720 
721 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
722  value.print(os);
723  return os;
724 }
725 
726 inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
727  PDLValue::print(os, kind);
728  return os;
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // PDLResultList
733 
734 /// The class represents a list of PDL results, returned by a native rewrite
735 /// method. It provides the mechanism with which to pass PDLValues back to the
736 /// PDL bytecode.
738 public:
739  /// Push a new Attribute value onto the result list.
740  void push_back(Attribute value) { results.push_back(value); }
741 
742  /// Push a new Operation onto the result list.
743  void push_back(Operation *value) { results.push_back(value); }
744 
745  /// Push a new Type onto the result list.
746  void push_back(Type value) { results.push_back(value); }
747 
748  /// Push a new TypeRange onto the result list.
750  // The lifetime of a TypeRange can't be guaranteed, so we'll need to
751  // allocate a storage for it.
752  llvm::OwningArrayRef<Type> storage(value.size());
753  llvm::copy(value, storage.begin());
754  allocatedTypeRanges.emplace_back(std::move(storage));
755  typeRanges.push_back(allocatedTypeRanges.back());
756  results.push_back(&typeRanges.back());
757  }
759  typeRanges.push_back(value);
760  results.push_back(&typeRanges.back());
761  }
763  typeRanges.push_back(value);
764  results.push_back(&typeRanges.back());
765  }
766 
767  /// Push a new Value onto the result list.
768  void push_back(Value value) { results.push_back(value); }
769 
770  /// Push a new ValueRange onto the result list.
772  // The lifetime of a ValueRange can't be guaranteed, so we'll need to
773  // allocate a storage for it.
774  llvm::OwningArrayRef<Value> storage(value.size());
775  llvm::copy(value, storage.begin());
776  allocatedValueRanges.emplace_back(std::move(storage));
777  valueRanges.push_back(allocatedValueRanges.back());
778  results.push_back(&valueRanges.back());
779  }
781  valueRanges.push_back(value);
782  results.push_back(&valueRanges.back());
783  }
785  valueRanges.push_back(value);
786  results.push_back(&valueRanges.back());
787  }
788 
789 protected:
790  /// Create a new result list with the expected number of results.
791  PDLResultList(unsigned maxNumResults) {
792  // For now just reserve enough space for all of the results. We could do
793  // separate counts per range type, but it isn't really worth it unless there
794  // are a "large" number of results.
795  typeRanges.reserve(maxNumResults);
796  valueRanges.reserve(maxNumResults);
797  }
798 
799  /// The PDL results held by this list.
801  /// Memory used to store ranges held by the list.
804  /// Memory allocated to store ranges in the result list whose lifetime was
805  /// generated in the native function.
808 };
809 
810 //===----------------------------------------------------------------------===//
811 // PDLPatternConfig
812 
813 /// An individual configuration for a pattern, which can be accessed by native
814 /// functions via the PDLPatternConfigSet. This allows for injecting additional
815 /// configuration into PDL patterns that is specific to certain compilation
816 /// flows.
818 public:
819  virtual ~PDLPatternConfig() = default;
820 
821  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
822  /// pattern. These can be used to setup any specific state necessary for the
823  /// rewrite.
824  virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
825  virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
826 
827  /// Return the TypeID that represents this configuration.
828  TypeID getTypeID() const { return id; }
829 
830 protected:
831  PDLPatternConfig(TypeID id) : id(id) {}
832 
833 private:
834  TypeID id;
835 };
836 
837 /// This class provides a base class for users implementing a type of pattern
838 /// configuration.
839 template <typename T>
841 public:
842  /// Support LLVM style casting.
843  static bool classof(const PDLPatternConfig *config) {
844  return config->getTypeID() == getConfigID();
845  }
846 
847  /// Return the type id used for this configuration.
848  static TypeID getConfigID() { return TypeID::get<T>(); }
849 
850 protected:
852 };
853 
854 /// This class contains a set of configurations for a specific pattern.
855 /// Configurations are uniqued by TypeID, meaning that only one configuration of
856 /// each type is allowed.
858 public:
859  PDLPatternConfigSet() = default;
860 
861  /// Construct a set with the given configurations.
862  template <typename... ConfigsT>
863  PDLPatternConfigSet(ConfigsT &&...configs) {
864  (addConfig(std::forward<ConfigsT>(configs)), ...);
865  }
866 
867  /// Get the configuration defined by the given type. Asserts that the
868  /// configuration of the provided type exists.
869  template <typename T>
870  const T &get() const {
871  const T *config = tryGet<T>();
872  assert(config && "configuration not found");
873  return *config;
874  }
875 
876  /// Get the configuration defined by the given type, returns nullptr if the
877  /// configuration does not exist.
878  template <typename T>
879  const T *tryGet() const {
880  for (const auto &configIt : configs)
881  if (const T *config = dyn_cast<T>(configIt.get()))
882  return config;
883  return nullptr;
884  }
885 
886  /// Notify the configurations within this set at the beginning or end of a
887  /// rewrite of a matched pattern.
889  for (const auto &config : configs)
890  config->notifyRewriteBegin(rewriter);
891  }
893  for (const auto &config : configs)
894  config->notifyRewriteEnd(rewriter);
895  }
896 
897 protected:
898  /// Add a configuration to the set.
899  template <typename T>
900  void addConfig(T &&config) {
901  assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
902  configs.emplace_back(
903  std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
904  }
905 
906  /// The set of configurations for this pattern. This uses a vector instead of
907  /// a map with the expectation that the number of configurations per set is
908  /// small (<= 1).
910 };
911 
912 //===----------------------------------------------------------------------===//
913 // PDLPatternModule
914 
915 /// A generic PDL pattern constraint function. This function applies a
916 /// constraint to a given set of opaque PDLValue entities. Returns success if
917 /// the constraint successfully held, failure otherwise.
920 /// A native PDL rewrite function. This function performs a rewrite on the
921 /// given set of values. Any results from this rewrite that should be passed
922 /// back to PDL should be added to the provided result list. This method is only
923 /// invoked when the corresponding match was successful. Returns failure if an
924 /// invariant of the rewrite was broken (certain rewriters may recover from
925 /// partial pattern application).
926 using PDLRewriteFunction = std::function<LogicalResult(
928 
929 namespace detail {
930 namespace pdl_function_builder {
931 /// A utility variable that always resolves to false. This is useful for static
932 /// asserts that are always false, but only should fire in certain templated
933 /// constructs. For example, if a templated function should never be called, the
934 /// function could be defined as:
935 ///
936 /// template <typename T>
937 /// void foo() {
938 /// static_assert(always_false<T>, "This function should never be called");
939 /// }
940 ///
941 template <class... T>
942 constexpr bool always_false = false;
943 
944 //===----------------------------------------------------------------------===//
945 // PDL Function Builder: Type Processing
946 //===----------------------------------------------------------------------===//
947 
948 /// This struct provides a convenient way to determine how to process a given
949 /// type as either a PDL parameter, or a result value. This allows for
950 /// supporting complex types in constraint and rewrite functions, without
951 /// requiring the user to hand-write the necessary glue code themselves.
952 /// Specializations of this class should implement the following methods to
953 /// enable support as a PDL argument or result type:
954 ///
955 /// static LogicalResult verifyAsArg(
956 /// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
957 /// size_t argIdx);
958 ///
959 /// * This method verifies that the given PDLValue is valid for use as a
960 /// value of `T`.
961 ///
962 /// static T processAsArg(PDLValue pdlValue);
963 ///
964 /// * This method processes the given PDLValue as a value of `T`.
965 ///
966 /// static void processAsResult(PatternRewriter &, PDLResultList &results,
967 /// const T &value);
968 ///
969 /// * This method processes the given value of `T` as the result of a
970 /// function invocation. The method should package the value into an
971 /// appropriate form and append it to the given result list.
972 ///
973 /// If the type `T` is based on a higher order value, consider using
974 /// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
975 /// the implementation.
976 ///
977 template <typename T, typename Enable = void>
979 
980 /// This struct provides a simplified model for processing types that are based
981 /// on another type, e.g. APInt is based on the handling for IntegerAttr. This
982 /// allows for building the necessary processing functions on top of the base
983 /// value instead of a PDLValue. Derived users should implement the following
984 /// (which subsume the ProcessPDLValue variants):
985 ///
986 /// static LogicalResult verifyAsArg(
987 /// function_ref<LogicalResult(const Twine &)> errorFn,
988 /// const BaseT &baseValue, size_t argIdx);
989 ///
990 /// * This method verifies that the given PDLValue is valid for use as a
991 /// value of `T`.
992 ///
993 /// static T processAsArg(BaseT baseValue);
994 ///
995 /// * This method processes the given base value as a value of `T`.
996 ///
997 template <typename T, typename BaseT>
999  static LogicalResult
1000  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1001  PDLValue pdlValue, size_t argIdx) {
1002  // Verify the base class before continuing.
1003  if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
1004  return failure();
1006  errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
1007  }
1008  static T processAsArg(PDLValue pdlValue) {
1011  }
1012 
1013  /// Explicitly add the expected parent API to ensure the parent class
1014  /// implements the necessary API (and doesn't implicitly inherit it from
1015  /// somewhere else).
1016  static LogicalResult
1018  size_t argIdx) {
1019  return success();
1020  }
1021  static T processAsArg(BaseT baseValue);
1022 };
1023 
1024 /// This struct provides a simplified model for processing types that have
1025 /// "builtin" PDLValue support:
1026 /// * Attribute, Operation *, Type, TypeRange, ValueRange
1027 template <typename T>
1029  static LogicalResult
1030  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1031  PDLValue pdlValue, size_t argIdx) {
1032  if (pdlValue)
1033  return success();
1034  return errorFn("expected a non-null value for argument " + Twine(argIdx) +
1035  " of type: " + llvm::getTypeName<T>());
1036  }
1037 
1038  static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
1040  T value) {
1041  results.push_back(value);
1042  }
1043 };
1044 
1045 /// This struct provides a simplified model for processing types that inherit
1046 /// from builtin PDLValue types. For example, derived attributes like
1047 /// IntegerAttr, derived types like IntegerType, derived operations like
1048 /// ModuleOp, Interfaces, etc.
1049 template <typename T, typename BaseT>
1051  static LogicalResult
1052  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1053  BaseT baseValue, size_t argIdx) {
1054  return TypeSwitch<BaseT, LogicalResult>(baseValue)
1055  .Case([&](T) { return success(); })
1056  .Default([&](BaseT) {
1057  return errorFn("expected argument " + Twine(argIdx) +
1058  " to be of type: " + llvm::getTypeName<T>());
1059  });
1060  }
1062 
1063  static T processAsArg(BaseT baseValue) {
1064  return baseValue.template cast<T>();
1065  }
1067 
1069  T value) {
1070  results.push_back(value);
1071  }
1072 };
1073 
1074 //===----------------------------------------------------------------------===//
1075 // Attribute
1076 
1077 template <>
1078 struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
1079 template <typename T>
1081  std::enable_if_t<std::is_base_of<Attribute, T>::value>>
1082  : public ProcessDerivedPDLValue<T, Attribute> {};
1083 
1084 /// Handling for various Attribute value types.
1085 template <>
1086 struct ProcessPDLValue<StringRef>
1087  : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
1088  static StringRef processAsArg(StringAttr value) { return value.getValue(); }
1090 
1091  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
1092  StringRef value) {
1093  results.push_back(rewriter.getStringAttr(value));
1094  }
1095 };
1096 template <>
1097 struct ProcessPDLValue<std::string>
1098  : public ProcessPDLValueBasedOn<std::string, StringAttr> {
1099  template <typename T>
1100  static std::string processAsArg(T value) {
1101  static_assert(always_false<T>,
1102  "`std::string` arguments require a string copy, use "
1103  "`StringRef` for string-like arguments instead");
1104  return {};
1105  }
1106  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
1107  StringRef value) {
1108  results.push_back(rewriter.getStringAttr(value));
1109  }
1110 };
1111 
1112 //===----------------------------------------------------------------------===//
1113 // Operation
1114 
1115 template <>
1118 template <typename T>
1119 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
1120  : public ProcessDerivedPDLValue<T, Operation *> {
1121  static T processAsArg(Operation *value) { return cast<T>(value); }
1122 };
1123 
1124 //===----------------------------------------------------------------------===//
1125 // Type
1126 
1127 template <>
1128 struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
1129 template <typename T>
1130 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
1131  : public ProcessDerivedPDLValue<T, Type> {};
1132 
1133 //===----------------------------------------------------------------------===//
1134 // TypeRange
1135 
1136 template <>
1137 struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
1138 template <>
1142  results.push_back(types);
1143  }
1144 };
1145 template <>
1149  results.push_back(types);
1150  }
1151 };
1152 template <unsigned N>
1155  SmallVector<Type, N> values) {
1156  results.push_back(TypeRange(values));
1157  }
1158 };
1159 
1160 //===----------------------------------------------------------------------===//
1161 // Value
1162 
1163 template <>
1164 struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
1165 
1166 //===----------------------------------------------------------------------===//
1167 // ValueRange
1168 
1169 template <>
1170 struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
1171 };
1172 template <>
1175  OperandRange values) {
1176  results.push_back(values);
1177  }
1178 };
1179 template <>
1182  ResultRange values) {
1183  results.push_back(values);
1184  }
1185 };
1186 template <unsigned N>
1189  SmallVector<Value, N> values) {
1190  results.push_back(ValueRange(values));
1191  }
1192 };
1193 
1194 //===----------------------------------------------------------------------===//
1195 // PDL Function Builder: Argument Handling
1196 //===----------------------------------------------------------------------===//
1197 
1198 /// Validate the given PDLValues match the constraints defined by the argument
1199 /// types of the given function. In the case of failure, a match failure
1200 /// diagnostic is emitted.
1201 /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
1202 /// does not currently preserve Constraint application ordering.
1203 template <typename PDLFnT, std::size_t... I>
1205  std::index_sequence<I...>) {
1206  using FnTraitsT = llvm::function_traits<PDLFnT>;
1207 
1208  auto errorFn = [&](const Twine &msg) {
1209  return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
1210  };
1211  return success(
1212  (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1213  verifyAsArg(errorFn, values[I], I)) &&
1214  ...));
1215 }
1216 
1217 /// Assert that the given PDLValues match the constraints defined by the
1218 /// arguments of the given function. In the case of failure, a fatal error
1219 /// is emitted.
1220 template <typename PDLFnT, std::size_t... I>
1222  std::index_sequence<I...>) {
1223  // We only want to do verification in debug builds, same as with `assert`.
1224 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1225  using FnTraitsT = llvm::function_traits<PDLFnT>;
1226  auto errorFn = [&](const Twine &msg) -> LogicalResult {
1227  llvm::report_fatal_error(msg);
1228  };
1229  (void)errorFn;
1230  assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1231  verifyAsArg(errorFn, values[I], I)) &&
1232  ...));
1233 #endif
1234  (void)values;
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // PDL Function Builder: Results Handling
1239 //===----------------------------------------------------------------------===//
1240 
1241 /// Store a single result within the result list.
1242 template <typename T>
1244  PDLResultList &results, T &&value) {
1245  ProcessPDLValue<T>::processAsResult(rewriter, results,
1246  std::forward<T>(value));
1247  return success();
1248 }
1249 
1250 /// Store a std::pair<> as individual results within the result list.
1251 template <typename T1, typename T2>
1253  PDLResultList &results,
1254  std::pair<T1, T2> &&pair) {
1255  if (failed(processResults(rewriter, results, std::move(pair.first))) ||
1256  failed(processResults(rewriter, results, std::move(pair.second))))
1257  return failure();
1258  return success();
1259 }
1260 
1261 /// Store a std::tuple<> as individual results within the result list.
1262 template <typename... Ts>
1264  PDLResultList &results,
1265  std::tuple<Ts...> &&tuple) {
1266  auto applyFn = [&](auto &&...args) {
1267  return (succeeded(processResults(rewriter, results, std::move(args))) &&
1268  ...);
1269  };
1270  return success(std::apply(applyFn, std::move(tuple)));
1271 }
1272 
1273 /// Handle LogicalResult propagation.
1275  PDLResultList &results,
1276  LogicalResult &&result) {
1277  return result;
1278 }
1279 template <typename T>
1281  PDLResultList &results,
1282  FailureOr<T> &&result) {
1283  if (failed(result))
1284  return failure();
1285  return processResults(rewriter, results, std::move(*result));
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // PDL Constraint Builder
1290 //===----------------------------------------------------------------------===//
1291 
1292 /// Process the arguments of a native constraint and invoke it.
1293 template <typename PDLFnT, std::size_t... I,
1294  typename FnTraitsT = llvm::function_traits<PDLFnT>>
1295 typename FnTraitsT::result_t
1297  ArrayRef<PDLValue> values,
1298  std::index_sequence<I...>) {
1299  return fn(
1300  rewriter,
1301  (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1302  values[I]))...);
1303 }
1304 
1305 /// Build a constraint function from the given function `ConstraintFnT`. This
1306 /// allows for enabling the user to define simpler, more direct constraint
1307 /// functions without needing to handle the low-level PDL goop.
1308 ///
1309 /// If the constraint function is already in the correct form, we just forward
1310 /// it directly.
1311 template <typename ConstraintFnT>
1312 std::enable_if_t<
1315 buildConstraintFn(ConstraintFnT &&constraintFn) {
1316  return std::forward<ConstraintFnT>(constraintFn);
1317 }
1318 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1319 /// we desire.
1320 template <typename ConstraintFnT>
1321 std::enable_if_t<
1324 buildConstraintFn(ConstraintFnT &&constraintFn) {
1325  return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1326  PatternRewriter &rewriter,
1327  ArrayRef<PDLValue> values) -> LogicalResult {
1328  auto argIndices = std::make_index_sequence<
1329  llvm::function_traits<ConstraintFnT>::num_args - 1>();
1330  if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1331  return failure();
1332  return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
1333  argIndices);
1334  };
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // PDL Rewrite Builder
1339 //===----------------------------------------------------------------------===//
1340 
1341 /// Process the arguments of a native rewrite and invoke it.
1342 /// This overload handles the case of no return values.
1343 template <typename PDLFnT, std::size_t... I,
1344  typename FnTraitsT = llvm::function_traits<PDLFnT>>
1346  LogicalResult>
1349  std::index_sequence<I...>) {
1350  fn(rewriter,
1351  (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1352  values[I]))...);
1353  return success();
1354 }
1355 /// This overload handles the case of return values, which need to be packaged
1356 /// into the result list.
1357 template <typename PDLFnT, std::size_t... I,
1358  typename FnTraitsT = llvm::function_traits<PDLFnT>>
1360  LogicalResult>
1362  PDLResultList &results, ArrayRef<PDLValue> values,
1363  std::index_sequence<I...>) {
1364  return processResults(
1365  rewriter, results,
1366  fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1367  processAsArg(values[I]))...));
1368  (void)values;
1369 }
1370 
1371 /// Build a rewrite function from the given function `RewriteFnT`. This
1372 /// allows for enabling the user to define simpler, more direct rewrite
1373 /// functions without needing to handle the low-level PDL goop.
1374 ///
1375 /// If the rewrite function is already in the correct form, we just forward
1376 /// it directly.
1377 template <typename RewriteFnT>
1380 buildRewriteFn(RewriteFnT &&rewriteFn) {
1381  return std::forward<RewriteFnT>(rewriteFn);
1382 }
1383 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1384 /// we desire.
1385 template <typename RewriteFnT>
1388 buildRewriteFn(RewriteFnT &&rewriteFn) {
1389  return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1390  PatternRewriter &rewriter, PDLResultList &results,
1391  ArrayRef<PDLValue> values) {
1392  auto argIndices =
1393  std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1394  1>();
1395  assertArgs<RewriteFnT>(rewriter, values, argIndices);
1396  return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
1397  argIndices);
1398  };
1399 }
1400 
1401 } // namespace pdl_function_builder
1402 } // namespace detail
1403 
1404 //===----------------------------------------------------------------------===//
1405 // PDLPatternModule
1406 
1407 /// This class contains all of the necessary data for a set of PDL patterns, or
1408 /// pattern rewrites specified in the form of the PDL dialect. This PDL module
1409 /// contained by this pattern may contain any number of `pdl.pattern`
1410 /// operations.
1412 public:
1413  PDLPatternModule() = default;
1414 
1415  /// Construct a PDL pattern with the given module and configurations.
1417  : pdlModule(std::move(module)) {}
1418  template <typename... ConfigsT>
1419  PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
1420  : PDLPatternModule(std::move(module)) {
1421  auto configSet = std::make_unique<PDLPatternConfigSet>(
1422  std::forward<ConfigsT>(patternConfigs)...);
1423  attachConfigToPatterns(*pdlModule, *configSet);
1424  configs.emplace_back(std::move(configSet));
1425  }
1426 
1427  /// Merge the state in `other` into this pattern module.
1428  void mergeIn(PDLPatternModule &&other);
1429 
1430  /// Return the internal PDL module of this pattern.
1431  ModuleOp getModule() { return pdlModule.get(); }
1432 
1433  //===--------------------------------------------------------------------===//
1434  // Function Registry
1435 
1436  /// Register a constraint function with PDL. A constraint function may be
1437  /// specified in one of two ways:
1438  ///
1439  /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
1440  ///
1441  /// In this overload the arguments of the constraint function are passed via
1442  /// the low-level PDLValue form.
1443  ///
1444  /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
1445  ///
1446  /// In this form the arguments of the constraint function are passed via the
1447  /// expected high level C++ type. In this form, the framework will
1448  /// automatically unwrap PDLValues and convert them to the expected ValueTs.
1449  /// For example, if the constraint function accepts a `Operation *`, the
1450  /// framework will automatically cast the input PDLValue. In the case of a
1451  /// `StringRef`, the framework will automatically unwrap the argument as a
1452  /// StringAttr and pass the underlying string value. To see the full list of
1453  /// supported types, or to see how to add handling for custom types, view
1454  /// the definition of `ProcessPDLValue` above.
1455  void registerConstraintFunction(StringRef name,
1456  PDLConstraintFunction constraintFn);
1457  template <typename ConstraintFnT>
1458  void registerConstraintFunction(StringRef name,
1459  ConstraintFnT &&constraintFn) {
1462  std::forward<ConstraintFnT>(constraintFn)));
1463  }
1464 
1465  /// Register a rewrite function with PDL. A rewrite function may be specified
1466  /// in one of two ways:
1467  ///
1468  /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
1469  ///
1470  /// In this overload the arguments of the constraint function are passed via
1471  /// the low-level PDLValue form, and the results are manually appended to
1472  /// the given result list.
1473  ///
1474  /// * `ResultT (PatternRewriter &, ValueTs... values)`
1475  ///
1476  /// In this form the arguments and result of the rewrite function are passed
1477  /// via the expected high level C++ type. In this form, the framework will
1478  /// automatically unwrap the PDLValues arguments and convert them to the
1479  /// expected ValueTs. It will also automatically handle the processing and
1480  /// packaging of the result value to the result list. For example, if the
1481  /// rewrite function takes a `Operation *`, the framework will automatically
1482  /// cast the input PDLValue. In the case of a `StringRef`, the framework
1483  /// will automatically unwrap the argument as a StringAttr and pass the
1484  /// underlying string value. In the reverse case, if the rewrite returns a
1485  /// StringRef or std::string, it will automatically package this as a
1486  /// StringAttr and append it to the result list. To see the full list of
1487  /// supported types, or to see how to add handling for custom types, view
1488  /// the definition of `ProcessPDLValue` above.
1489  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
1490  template <typename RewriteFnT>
1491  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
1493  std::forward<RewriteFnT>(rewriteFn)));
1494  }
1495 
1496  /// Return the set of the registered constraint functions.
1497  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
1498  return constraintFunctions;
1499  }
1500  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
1501  return constraintFunctions;
1502  }
1503  /// Return the set of the registered rewrite functions.
1504  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
1505  return rewriteFunctions;
1506  }
1507  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
1508  return rewriteFunctions;
1509  }
1510 
1511  /// Return the set of the registered pattern configs.
1513  return std::move(configs);
1514  }
1516  return std::move(configMap);
1517  }
1518 
1519  /// Clear out the patterns and functions within this module.
1520  void clear() {
1521  pdlModule = nullptr;
1522  constraintFunctions.clear();
1523  rewriteFunctions.clear();
1524  }
1525 
1526 private:
1527  /// Attach the given pattern config set to the patterns defined within the
1528  /// given module.
1529  void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
1530 
1531  /// The module containing the `pdl.pattern` operations.
1532  OwningOpRef<ModuleOp> pdlModule;
1533 
1534  /// The set of configuration sets referenced by patterns within `pdlModule`.
1537 
1538  /// The external functions referenced from within the PDL module.
1539  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
1540  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1541 };
1542 
1543 //===----------------------------------------------------------------------===//
1544 // RewritePatternSet
1545 //===----------------------------------------------------------------------===//
1546 
1548  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1549 
1550 public:
1551  RewritePatternSet(MLIRContext *context) : context(context) {}
1552 
1553  /// Construct a RewritePatternSet populated with the given pattern.
1555  std::unique_ptr<RewritePattern> pattern)
1556  : context(context) {
1557  nativePatterns.emplace_back(std::move(pattern));
1558  }
1560  : context(pattern.getModule()->getContext()),
1561  pdlPatterns(std::move(pattern)) {}
1562 
1563  MLIRContext *getContext() const { return context; }
1564 
1565  /// Return the native patterns held in this list.
1566  NativePatternListT &getNativePatterns() { return nativePatterns; }
1567 
1568  /// Return the PDL patterns held in this list.
1569  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
1570 
1571  /// Clear out all of the held patterns in this list.
1572  void clear() {
1573  nativePatterns.clear();
1574  pdlPatterns.clear();
1575  }
1576 
1577  //===--------------------------------------------------------------------===//
1578  // 'add' methods for adding patterns to the set.
1579  //===--------------------------------------------------------------------===//
1580 
1581  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1582  /// the given arguments. Return a reference to `this` for chaining insertions.
1583  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1584  template <typename... Ts, typename ConstructorArg,
1585  typename... ConstructorArgs,
1586  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1587  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
1588  // The following expands a call to emplace_back for each of the pattern
1589  // types 'Ts'.
1590  (addImpl<Ts>(/*debugLabels=*/llvm::None, std::forward<ConstructorArg>(arg),
1591  std::forward<ConstructorArgs>(args)...),
1592  ...);
1593  return *this;
1594  }
1595  /// An overload of the above `add` method that allows for attaching a set
1596  /// of debug labels to the attached patterns. This is useful for labeling
1597  /// groups of patterns that may be shared between multiple different
1598  /// passes/users.
1599  template <typename... Ts, typename ConstructorArg,
1600  typename... ConstructorArgs,
1601  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1603  ConstructorArg &&arg,
1604  ConstructorArgs &&...args) {
1605  // The following expands a call to emplace_back for each of the pattern
1606  // types 'Ts'.
1607  (addImpl<Ts>(debugLabels, arg, args...), ...);
1608  return *this;
1609  }
1610 
1611  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1612  /// `this` for chaining insertions.
1613  template <typename... Ts>
1615  (addImpl<Ts>(), ...);
1616  return *this;
1617  }
1618 
1619  /// Add the given native pattern to the pattern list. Return a reference to
1620  /// `this` for chaining insertions.
1621  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
1622  nativePatterns.emplace_back(std::move(pattern));
1623  return *this;
1624  }
1625 
1626  /// Add the given PDL pattern to the pattern list. Return a reference to
1627  /// `this` for chaining insertions.
1629  pdlPatterns.mergeIn(std::move(pattern));
1630  return *this;
1631  }
1632 
1633  // Add a matchAndRewrite style pattern represented as a C function pointer.
1634  template <typename OpType>
1636  PatternRewriter &rewriter)) {
1637  struct FnPattern final : public OpRewritePattern<OpType> {
1638  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1639  MLIRContext *context)
1640  : OpRewritePattern<OpType>(context), implFn(implFn) {}
1641 
1642  LogicalResult matchAndRewrite(OpType op,
1643  PatternRewriter &rewriter) const override {
1644  return implFn(op, rewriter);
1645  }
1646 
1647  private:
1648  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1649  };
1650  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1651  return *this;
1652  }
1653 
1654  //===--------------------------------------------------------------------===//
1655  // Pattern Insertion
1656  //===--------------------------------------------------------------------===//
1657 
1658  // TODO: These are soft deprecated in favor of the 'add' methods above.
1659 
1660  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1661  /// the given arguments. Return a reference to `this` for chaining insertions.
1662  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1663  template <typename... Ts, typename ConstructorArg,
1664  typename... ConstructorArgs,
1665  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1666  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
1667  // The following expands a call to emplace_back for each of the pattern
1668  // types 'Ts'.
1669  (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), ...);
1670  return *this;
1671  }
1672 
1673  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1674  /// `this` for chaining insertions.
1675  template <typename... Ts>
1677  (addImpl<Ts>(), ...);
1678  return *this;
1679  }
1680 
1681  /// Add the given native pattern to the pattern list. Return a reference to
1682  /// `this` for chaining insertions.
1683  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
1684  nativePatterns.emplace_back(std::move(pattern));
1685  return *this;
1686  }
1687 
1688  /// Add the given PDL pattern to the pattern list. Return a reference to
1689  /// `this` for chaining insertions.
1691  pdlPatterns.mergeIn(std::move(pattern));
1692  return *this;
1693  }
1694 
1695  // Add a matchAndRewrite style pattern represented as a C function pointer.
1696  template <typename OpType>
1698  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
1699  struct FnPattern final : public OpRewritePattern<OpType> {
1700  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1701  MLIRContext *context)
1702  : OpRewritePattern<OpType>(context), implFn(implFn) {
1703  this->setDebugName(llvm::getTypeName<FnPattern>());
1704  }
1705 
1706  LogicalResult matchAndRewrite(OpType op,
1707  PatternRewriter &rewriter) const override {
1708  return implFn(op, rewriter);
1709  }
1710 
1711  private:
1712  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1713  };
1714  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1715  return *this;
1716  }
1717 
1718 private:
1719  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
1720  /// chaining insertions.
1721  template <typename T, typename... Args>
1723  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1724  std::unique_ptr<T> pattern =
1725  RewritePattern::create<T>(std::forward<Args>(args)...);
1726  pattern->addDebugLabels(debugLabels);
1727  nativePatterns.emplace_back(std::move(pattern));
1728  }
1729  template <typename T, typename... Args>
1731  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1732  // TODO: Add the provided labels to the PDL pattern when PDL supports
1733  // labels.
1734  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1735  }
1736 
1737  MLIRContext *const context;
1738  NativePatternListT nativePatterns;
1739  PDLPatternModule pdlPatterns;
1740 };
1741 
1742 } // namespace mlir
1743 
1744 #endif // MLIR_IR_PATTERNMATCH_H
static std::string diag(llvm::Value &value)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const bool value
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
Location getUnknownLoc()
Definition: Builders.cpp:26
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:589
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:592
IRRewriter(MLIRContext *ctx)
Definition: PatternMatch.h:591
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class helps build Operations.
Definition: Builders.h:198
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:268
This class represents an operand of an operation.
Definition: Value.h:247
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:382
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:384
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
static OperationName getFromOpaquePointer(const void *pointer)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:50
This class provides a base class for users implementing a type of pattern configuration.
Definition: PatternMatch.h:840
static TypeID getConfigID()
Return the type id used for this configuration.
Definition: PatternMatch.h:848
static bool classof(const PDLPatternConfig *config)
Support LLVM style casting.
Definition: PatternMatch.h:843
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:857
const T & get() const
Get the configuration defined by the given type.
Definition: PatternMatch.h:870
PDLPatternConfigSet(ConfigsT &&...configs)
Construct a set with the given configurations.
Definition: PatternMatch.h:863
const T * tryGet() const
Get the configuration defined by the given type, returns nullptr if the configuration does not exist.
Definition: PatternMatch.h:879
SmallVector< std::unique_ptr< PDLPatternConfig > > configs
The set of configurations for this pattern.
Definition: PatternMatch.h:909
void addConfig(T &&config)
Add a configuration to the set.
Definition: PatternMatch.h:900
void notifyRewriteBegin(PatternRewriter &rewriter)
Notify the configurations within this set at the beginning or end of a rewrite of a matched pattern.
Definition: PatternMatch.h:888
void notifyRewriteEnd(PatternRewriter &rewriter)
Definition: PatternMatch.h:892
An individual configuration for a pattern, which can be accessed by native functions via the PDLPatte...
Definition: PatternMatch.h:817
virtual ~PDLPatternConfig()=default
PDLPatternConfig(TypeID id)
Definition: PatternMatch.h:831
virtual void notifyRewriteEnd(PatternRewriter &rewriter)
Definition: PatternMatch.h:825
virtual void notifyRewriteBegin(PatternRewriter &rewriter)
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
Definition: PatternMatch.h:824
TypeID getTypeID() const
Return the TypeID that represents this configuration.
Definition: PatternMatch.h:828
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void clear()
Clear out the patterns and functions within this module.
const llvm::StringMap< PDLRewriteFunction > & getRewriteFunctions() const
Return the set of the registered rewrite functions.
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
PDLPatternModule(OwningOpRef< ModuleOp > module, ConfigsT &&...patternConfigs)
void registerConstraintFunction(StringRef name, ConstraintFnT &&constraintFn)
PDLPatternModule(OwningOpRef< ModuleOp > module)
Construct a PDL pattern with the given module and configurations.
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn)
ModuleOp getModule()
Return the internal PDL module of this pattern.
const llvm::StringMap< PDLConstraintFunction > & getConstraintFunctions() const
Return the set of the registered constraint functions.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
SmallVector< std::unique_ptr< PDLPatternConfigSet > > takeConfigs()
Return the set of the registered pattern configs.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
DenseMap< Operation *, PDLPatternConfigSet * > takeConfigMap()
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:737
void push_back(ValueTypeRange< OperandRange > value)
Definition: PatternMatch.h:758
void push_back(ResultRange value)
Definition: PatternMatch.h:784
void push_back(ValueRange value)
Push a new ValueRange onto the result list.
Definition: PatternMatch.h:771
PDLResultList(unsigned maxNumResults)
Create a new result list with the expected number of results.
Definition: PatternMatch.h:791
SmallVector< llvm::OwningArrayRef< Type > > allocatedTypeRanges
Memory allocated to store ranges in the result list whose lifetime was generated in the native functi...
Definition: PatternMatch.h:806
void push_back(ValueTypeRange< ResultRange > value)
Definition: PatternMatch.h:762
SmallVector< llvm::OwningArrayRef< Value > > allocatedValueRanges
Definition: PatternMatch.h:807
void push_back(Attribute value)
Push a new Attribute value onto the result list.
Definition: PatternMatch.h:740
SmallVector< TypeRange > typeRanges
Memory used to store ranges held by the list.
Definition: PatternMatch.h:802
SmallVector< PDLValue > results
The PDL results held by this list.
Definition: PatternMatch.h:800
void push_back(Type value)
Push a new Type onto the result list.
Definition: PatternMatch.h:746
void push_back(Operation *value)
Push a new Operation onto the result list.
Definition: PatternMatch.h:743
void push_back(OperandRange value)
Definition: PatternMatch.h:780
void push_back(Value value)
Push a new Value onto the result list.
Definition: PatternMatch.h:768
SmallVector< ValueRange > valueRanges
Definition: PatternMatch.h:803
void push_back(TypeRange value)
Push a new TypeRange onto the result list.
Definition: PatternMatch.h:749
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:625
PDLValue(std::nullptr_t=nullptr)
Definition: PatternMatch.h:632
PDLValue(Type value)
Definition: PatternMatch.h:636
const void * getAsOpaquePointer() const
Get an opaque pointer to the value.
Definition: PatternMatch.h:667
PDLValue(Attribute value)
Definition: PatternMatch.h:633
PDLValue(Operation *value)
Definition: PatternMatch.h:635
Kind getKind() const
Return the kind of this value.
Definition: PatternMatch.h:673
ResultT dyn_cast() const
Attempt to dynamically cast this value to type T, returns null if this value is not an instance of T.
Definition: PatternMatch.h:654
bool isa() const
Returns true if the type of the held value is T.
Definition: PatternMatch.h:644
void print(raw_ostream &os) const
Print this value to the provided output stream.
PDLValue(TypeRange *value)
Definition: PatternMatch.h:637
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:628
T cast() const
Cast this value to type T, asserts if this value is not an instance of T.
Definition: PatternMatch.h:661
PDLValue(ValueRange *value)
Definition: PatternMatch.h:640
PDLValue(const PDLValue &other)=default
Construct a new PDL value.
PDLValue(Value value)
Definition: PatternMatch.h:638
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:52
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:48
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:41
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:56
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:51
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:55
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:613
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
Optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:110
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e. it may generate IR that it can match again, but the recursion is known to be bounded.
Definition: PatternMatch.h:127
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:145
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:200
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:88
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:121
Optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:101
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:142
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:151
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:148
Optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:92
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:138
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:230
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:274
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:262
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:517
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace it with to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Definition: PatternMatch.h:527
virtual void notifyRootReplaced(Operation *op, ValueRange replacement)
These are the callback methods that subclasses can choose to implement if they would like to be notif...
Definition: PatternMatch.h:556
void replaceOpWithIf(Operation *op, ValueRange newValues, llvm::unique_function< bool(OpOperand &) const > functor)
Definition: PatternMatch.h:430
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Definition: PatternMatch.h:532
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the rewriter that the pattern failed to match the given operation, and provide a callback to p...
Definition: PatternMatch.h:567
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:544
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
virtual void notifyOperationRemoved(Operation *op)
This is called on an operation that a rewrite is removing, right before the operation is deleted.
Definition: PatternMatch.h:560
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
Definition: PatternMatch.h:489
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
Definition: PatternMatch.h:493
RewriterBase(MLIRContext *ctx)
Initialize the builder with this rewriter as the listener.
Definition: PatternMatch.h:543
~RewriterBase() override
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
Definition: PatternMatch.h:537
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:484
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:132
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
FnTraitsT::result_t processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native constraint and invoke it.
void assertArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Assert that the given PDLValues match the constraints defined by the arguments of the given function.
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Validate the given PDLValues match the constraints defined by the argument types of the given functio...
std::enable_if_t< std::is_same< typename FnTraitsT::result_t, void >::value, LogicalResult > processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native rewrite and invoke it.
static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, T &&value)
Store a single result within the result list.
std::enable_if_t< std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Build a constraint function from the given function ConstraintFnT.
constexpr bool always_false
A utility variable that always resolves to false.
Definition: PatternMatch.h:942
std::enable_if_t< std::is_convertible< RewriteFnT, PDLRewriteFunction >::value, PDLRewriteFunction > buildRewriteFn(RewriteFnT &&rewriteFn)
Build a rewrite function from the given function RewriteFnT.
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:919
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:927
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:255
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:371
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:372
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:356
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:157
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:162
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:167
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:317
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: PatternMatch.h:334
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: PatternMatch.h:321
virtual LogicalResult match(SourceOp op) const
Definition: PatternMatch.h:337
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:324
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:340
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:327
This struct provides a simplified model for processing types that have "builtin" PDLValue support:
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
This struct provides a simplified model for processing types that inherit from builtin PDLValue types...
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT baseValue, size_t argIdx)
This struct provides a simplified model for processing types that are based on another type,...
Definition: PatternMatch.h:998
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT value, size_t argIdx)
Explicitly add the expected parent API to ensure the parent class implements the necessary API (and d...
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static void processAsResult(PatternRewriter &, PDLResultList &results, OperandRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, ResultRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Type, N > values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Value, N > values)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< OperandRange > types)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< ResultRange > types)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
This struct provides a convenient way to determine how to process a given type as either a PDL parame...
Definition: PatternMatch.h:978