MLIR  18.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 #include <optional>
17 
18 namespace mlir {
19 
20 class PatternRewriter;
21 
22 //===----------------------------------------------------------------------===//
23 // PatternBenefit class
24 //===----------------------------------------------------------------------===//
25 
26 /// This class represents the benefit of a pattern match in a unitless scheme
27 /// that ranges from 0 (very little benefit) to 65K. The most common unit to
28 /// use here is the "number of operations matched" by the pattern.
29 ///
30 /// This also has a sentinel representation that can be used for patterns that
31 /// fail to match.
32 ///
class PatternBenefit {
  /// Reserved representation marking a pattern that can never match.
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// A default-constructed benefit holds the "impossible to match" sentinel.
  PatternBenefit() = default;
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used for patterns that fail to match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }

  /// Return true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Comparisons order benefits by their raw representation.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  /// The stored benefit; starts out as the sentinel value.
  unsigned short representation{ImpossibleToMatchSentinel};
};
63 
64 //===----------------------------------------------------------------------===//
65 // Pattern
66 //===----------------------------------------------------------------------===//
67 
68 /// This class contains all of the data related to a pattern, but does not
69 /// contain any methods or logic for the actual matching. This class is solely
70 /// used to interface with the metadata of a pattern, such as the benefit or
71 /// root operation.
72 class Pattern {
73  /// This enum represents the kind of value used to select the root operations
74  /// that match this pattern.
75  enum class RootKind {
76  /// The pattern root matches "any" operation.
77  Any,
78  /// The pattern root is matched using a concrete operation name.
80  /// The pattern root is matched using an interface ID.
81  InterfaceID,
82  /// The patter root is matched using a trait ID.
83  TraitID
84  };
85 
86 public:
87  /// Return a list of operations that may be generated when rewriting an
88  /// operation instance with this pattern.
89  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
90 
91  /// Return the root node that this pattern matches. Patterns that can match
92  /// multiple root types return std::nullopt.
93  std::optional<OperationName> getRootKind() const {
94  if (rootKind == RootKind::OperationName)
95  return OperationName::getFromOpaquePointer(rootValue);
96  return std::nullopt;
97  }
98 
99  /// Return the interface ID used to match the root operation of this pattern.
100  /// If the pattern does not use an interface ID for deciding the root match,
101  /// this returns std::nullopt.
102  std::optional<TypeID> getRootInterfaceID() const {
103  if (rootKind == RootKind::InterfaceID)
104  return TypeID::getFromOpaquePointer(rootValue);
105  return std::nullopt;
106  }
107 
108  /// Return the trait ID used to match the root operation of this pattern.
109  /// If the pattern does not use a trait ID for deciding the root match, this
110  /// returns std::nullopt.
111  std::optional<TypeID> getRootTraitID() const {
112  if (rootKind == RootKind::TraitID)
113  return TypeID::getFromOpaquePointer(rootValue);
114  return std::nullopt;
115  }
116 
117  /// Return the benefit (the inverse of "cost") of matching this pattern. The
118  /// benefit of a Pattern is always static - rewrites that may have dynamic
119  /// benefit can be instantiated multiple times (different Pattern instances)
120  /// for each benefit that they may return, and be guarded by different match
121  /// condition predicates.
122  PatternBenefit getBenefit() const { return benefit; }
123 
124  /// Returns true if this pattern is known to result in recursive application,
125  /// i.e. this pattern may generate IR that also matches this pattern, but is
126  /// known to bound the recursion. This signals to a rewrite driver that it is
127  /// safe to apply this pattern recursively to generated IR.
129  return contextAndHasBoundedRecursion.getInt();
130  }
131 
132  /// Return the MLIRContext used to create this pattern.
134  return contextAndHasBoundedRecursion.getPointer();
135  }
136 
137  /// Return a readable name for this pattern. This name should only be used for
138  /// debugging purposes, and may be empty.
139  StringRef getDebugName() const { return debugName; }
140 
141  /// Set the human readable debug name used for this pattern. This name will
142  /// only be used for debugging purposes.
143  void setDebugName(StringRef name) { debugName = name; }
144 
145  /// Return the set of debug labels attached to this pattern.
146  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
147 
148  /// Add the provided debug labels to this pattern.
150  debugLabels.append(labels.begin(), labels.end());
151  }
152  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
153 
154 protected:
155  /// This class acts as a special tag that makes the desire to match "any"
156  /// operation type explicit. This helps to avoid unnecessary usages of this
157  /// feature, and ensures that the user is making a conscious decision.
158  struct MatchAnyOpTypeTag {};
159  /// This class acts as a special tag that makes the desire to match any
160  /// operation that implements a given interface explicit. This helps to avoid
161  /// unnecessary usages of this feature, and ensures that the user is making a
162  /// conscious decision.
164  /// This class acts as a special tag that makes the desire to match any
165  /// operation that implements a given trait explicit. This helps to avoid
166  /// unnecessary usages of this feature, and ensures that the user is making a
167  /// conscious decision.
169 
170  /// Construct a pattern with a certain benefit that matches the operation
171  /// with the given root name.
172  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
173  ArrayRef<StringRef> generatedNames = {});
174  /// Construct a pattern that may match any operation type. `generatedNames`
175  /// contains the names of operations that may be generated during a successful
176  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
177  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
178  /// always be supplied here.
179  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
180  ArrayRef<StringRef> generatedNames = {});
181  /// Construct a pattern that may match any operation that implements the
182  /// interface defined by the provided `interfaceID`. `generatedNames` contains
183  /// the names of operations that may be generated during a successful rewrite.
184  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
185  /// interface" behavior is what the user actually desired,
186  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
187  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
188  PatternBenefit benefit, MLIRContext *context,
189  ArrayRef<StringRef> generatedNames = {});
190  /// Construct a pattern that may match any operation that implements the
191  /// trait defined by the provided `traitID`. `generatedNames` contains the
192  /// names of operations that may be generated during a successful rewrite.
193  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
194  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
195  /// always be supplied here.
196  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
197  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
198 
199  /// Set the flag detailing if this pattern has bounded rewrite recursion or
200  /// not.
201  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
202  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
203  }
204 
205 private:
206  Pattern(const void *rootValue, RootKind rootKind,
207  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
208  MLIRContext *context);
209 
210  /// The value used to match the root operation of the pattern.
211  const void *rootValue;
212  RootKind rootKind;
213 
214  /// The expected benefit of matching this pattern.
215  const PatternBenefit benefit;
216 
217  /// The context this pattern was created from, and a boolean flag indicating
218  /// whether this pattern has bounded recursion or not.
219  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
220 
221  /// A list of the potential operations that may be generated when rewriting
222  /// an op with this pattern.
223  SmallVector<OperationName, 2> generatedOps;
224 
225  /// A readable name for this pattern. May be empty.
226  StringRef debugName;
227 
228  /// The set of debug labels attached to this pattern.
229  SmallVector<StringRef, 0> debugLabels;
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // RewritePattern
234 //===----------------------------------------------------------------------===//
235 
236 /// RewritePattern is the common base class for all DAG to DAG replacements.
237 /// There are two possible usages of this class:
238 ///   * Multi-step RewritePattern with "match" and "rewrite"
239 ///     - By overriding the "match" and "rewrite" functions, the user can
240 ///       separate the concerns of matching and rewriting.
241 ///   * Single-step RewritePattern with "matchAndRewrite"
242 ///     - By overriding the "matchAndRewrite" function, the user can perform
243 ///       the rewrite in the same call as the match.
244 ///
245 class RewritePattern : public Pattern {
246 public:
247  virtual ~RewritePattern() = default;
248 
249  /// Rewrite the IR rooted at the specified operation with the result of
250  /// this pattern, generating any new operations with the specified
251  /// builder. If an unexpected error is encountered (an internal
252  /// compiler error), it is emitted through the normal MLIR diagnostic
253  /// hooks and the IR is left in a valid state.
254  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
255 
256  /// Attempt to match against code rooted at the specified operation,
257  /// which is the same operation code as getRootKind().
258  virtual LogicalResult match(Operation *op) const;
259 
260  /// Attempt to match against code rooted at the specified operation,
261  /// which is the same operation code as getRootKind(). If successful, this
262  /// function will automatically perform the rewrite.
264  PatternRewriter &rewriter) const {
265  if (succeeded(match(op))) {
266  rewrite(op, rewriter);
267  return success();
268  }
269  return failure();
270  }
271 
272  /// This method provides a convenient interface for creating and initializing
273  /// derived rewrite patterns of the given type `T`.
274  template <typename T, typename... Args>
275  static std::unique_ptr<T> create(Args &&...args) {
276  std::unique_ptr<T> pattern =
277  std::make_unique<T>(std::forward<Args>(args)...);
278  initializePattern<T>(*pattern);
279 
280  // Set a default debug name if one wasn't provided.
281  if (pattern->getDebugName().empty())
282  pattern->setDebugName(llvm::getTypeName<T>());
283  return pattern;
284  }
285 
286 protected:
287  /// Inherit the base constructors from `Pattern`.
288  using Pattern::Pattern;
289 
290 private:
291  /// Trait to check if T provides a `getOperationName` method.
292  template <typename T, typename... Args>
293  using has_initialize = decltype(std::declval<T>().initialize());
294  template <typename T>
295  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
296 
297  /// Initialize the derived pattern by calling its `initialize` method.
298  template <typename T>
299  static std::enable_if_t<detect_has_initialize<T>::value>
300  initializePattern(T &pattern) {
301  pattern.initialize();
302  }
303  /// Empty derived pattern initializer for patterns that do not have an
304  /// initialize method.
305  template <typename T>
306  static std::enable_if_t<!detect_has_initialize<T>::value>
307  initializePattern(T &) {}
308 
309  /// An anchor for the virtual table.
310  virtual void anchor();
311 };
312 
313 namespace detail {
314 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
315 /// allows for matching and rewriting against an instance of a derived operation
316 /// class or Interface.
317 template <typename SourceOp>
319  using RewritePattern::RewritePattern;
320 
321  /// Wrappers around the RewritePattern methods that pass the derived op type.
322  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
323  rewrite(cast<SourceOp>(op), rewriter);
324  }
325  LogicalResult match(Operation *op) const final {
326  return match(cast<SourceOp>(op));
327  }
329  PatternRewriter &rewriter) const final {
330  return matchAndRewrite(cast<SourceOp>(op), rewriter);
331  }
332 
333  /// Rewrite and Match methods that operate on the SourceOp type. These must be
334  /// overridden by the derived pattern class.
335  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
336  llvm_unreachable("must override rewrite or matchAndRewrite");
337  }
338  virtual LogicalResult match(SourceOp op) const {
339  llvm_unreachable("must override match or matchAndRewrite");
340  }
341  virtual LogicalResult matchAndRewrite(SourceOp op,
342  PatternRewriter &rewriter) const {
343  if (succeeded(match(op))) {
344  rewrite(op, rewriter);
345  return success();
346  }
347  return failure();
348  }
349 };
350 } // namespace detail
351 
352 /// OpRewritePattern is a wrapper around RewritePattern that allows for
353 /// matching and rewriting against an instance of a derived operation class as
354 /// opposed to a raw Operation.
355 template <typename SourceOp>
357  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
358  /// Patterns must specify the root operation name they match against, and can
359  /// also specify the benefit of the pattern matching and a list of generated
360  /// ops.
362  ArrayRef<StringRef> generatedNames = {})
364  SourceOp::getOperationName(), benefit, context, generatedNames) {}
365 };
366 
367 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
368 /// matching and rewriting against an instance of an operation interface instead
369 /// of a raw Operation.
370 template <typename SourceOp>
372  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
374  : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
375  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
376  benefit, context) {}
377 };
378 
379 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
380 /// matching and rewriting against instances of an operation that possess a
381 /// given trait.
382 template <template <typename> class TraitType>
384 public:
387  benefit, context) {}
388 };
389 
390 //===----------------------------------------------------------------------===//
391 // RewriterBase
392 //===----------------------------------------------------------------------===//
393 
394 /// This class coordinates the application of a rewrite on a set of IR,
395 /// providing a way for clients to track mutations and create new operations.
396 /// This class serves as a common API for IR mutation between pattern rewrites
397 /// and non-pattern rewrites, and facilitates the development of shared
398 /// IR transformation utilities.
399 class RewriterBase : public OpBuilder {
400 public:
401  struct Listener : public OpBuilder::Listener {
403  : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
404 
405  /// Notify the listener that the specified operation was modified in-place.
406  virtual void notifyOperationModified(Operation *op) {}
407 
408  /// Notify the listener that the specified operation is about to be replaced
409  /// with another operation. This is called before the uses of the old
410  /// operation have been changed.
411  ///
412  /// By default, this function calls the "operation replaced with values"
413  /// notification.
415  Operation *replacement) {
416  notifyOperationReplaced(op, replacement->getResults());
417  }
418 
419  /// Notify the listener that the specified operation is about to be replaced
420  /// with a range of values, potentially produced by other operations.
421  /// This is called before the uses of the operation have been changed.
423  ValueRange replacement) {}
424 
425  /// Notify the listener that the specified operation is about to be erased.
426  /// At this point, the operation has zero uses.
427  virtual void notifyOperationRemoved(Operation *op) {}
428 
429  /// Notify the listener that the pattern failed to match the given
430  /// operation, and provide a callback to populate a diagnostic with the
431  /// reason why the failure occurred. This method allows for derived
432  /// listeners to optionally hook into the reason why a rewrite failed, and
433  /// display it to users.
434  virtual LogicalResult
436  function_ref<void(Diagnostic &)> reasonCallback) {
437  return failure();
438  }
439 
440  static bool classof(const OpBuilder::Listener *base);
441  };
442 
443  /// A listener that forwards all notifications to another listener. This
444  /// struct can be used as a base to create listener chains, so that multiple
445  /// listeners can be notified of IR changes.
447  ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
448 
449  void notifyOperationInserted(Operation *op) override {
450  listener->notifyOperationInserted(op);
451  }
452  void notifyBlockCreated(Block *block) override {
453  listener->notifyBlockCreated(block);
454  }
455  void notifyOperationModified(Operation *op) override {
456  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
457  rewriteListener->notifyOperationModified(op);
458  }
459  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
460  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
461  rewriteListener->notifyOperationReplaced(op, newOp);
462  }
464  ValueRange replacement) override {
465  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
466  rewriteListener->notifyOperationReplaced(op, replacement);
467  }
468  void notifyOperationRemoved(Operation *op) override {
469  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
470  rewriteListener->notifyOperationRemoved(op);
471  }
473  Location loc,
474  function_ref<void(Diagnostic &)> reasonCallback) override {
475  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
476  return rewriteListener->notifyMatchFailure(loc, reasonCallback);
477  return failure();
478  }
479 
480  private:
481  OpBuilder::Listener *listener;
482  };
483 
484  /// Move the blocks that belong to "region" before the given position in
485  /// another region "parent". The two regions must be different. The caller
486  /// is responsible for creating or updating the operation transferring flow
487  /// of control to the region and passing it the correct block arguments.
488  virtual void inlineRegionBefore(Region &region, Region &parent,
489  Region::iterator before);
490  void inlineRegionBefore(Region &region, Block *before);
491 
492  /// Clone the blocks that belong to "region" before the given position in
493  /// another region "parent". The two regions must be different. The caller is
494  /// responsible for creating or updating the operation transferring flow of
495  /// control to the region and passing it the correct block arguments.
496  virtual void cloneRegionBefore(Region &region, Region &parent,
497  Region::iterator before, IRMapping &mapping);
498  void cloneRegionBefore(Region &region, Region &parent,
499  Region::iterator before);
500  void cloneRegionBefore(Region &region, Block *before);
501 
502  /// This method replaces the uses of the results of `op` with the values in
503  /// `newValues` when the provided `functor` returns true for a specific use.
504  /// The number of values in `newValues` is required to match the number of
505  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
506  /// the uses of `op` were replaced. Note that in some rewriters, the given
507  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
508  /// As such, the function should not capture by reference and instead use
509  /// value capture as necessary.
510  virtual void
511  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
512  llvm::unique_function<bool(OpOperand &) const> functor);
513  void replaceOpWithIf(Operation *op, ValueRange newValues,
514  llvm::unique_function<bool(OpOperand &) const> functor) {
515  replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
516  std::move(functor));
517  }
518 
519  /// This method replaces the uses of the results of `op` with the values in
520  /// `newValues` when a use is nested within the given `block`. The number of
521  /// values in `newValues` is required to match the number of results of `op`.
522  /// If all uses of this operation are replaced, the operation is erased.
523  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
524  bool *allUsesReplaced = nullptr);
525 
526  /// This method replaces the results of the operation with the specified list
527  /// of values. The number of provided values must match the number of results
528  /// of the operation. The replaced op is erased.
529  virtual void replaceOp(Operation *op, ValueRange newValues);
530 
531  /// This method replaces the results of the operation with the specified
532  /// new op (replacement). The number of results of the two operations must
533  /// match. The replaced op is erased.
534  virtual void replaceOp(Operation *op, Operation *newOp);
535 
536  /// Replaces the result op with a new op that is created without verification.
537  /// The result values of the two ops must be the same types.
538  template <typename OpTy, typename... Args>
539  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
540  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
541  replaceOp(op, newOp.getOperation());
542  return newOp;
543  }
544 
545  /// This method erases an operation that is known to have no uses.
546  virtual void eraseOp(Operation *op);
547 
548  /// This method erases all operations in a block.
549  virtual void eraseBlock(Block *block);
550 
551  /// Inline the operations of block 'source' into block 'dest' before the given
552  /// position. The source block will be deleted and must have no uses.
553  /// 'argValues' is used to replace the block arguments of 'source'.
554  ///
555  /// If the source block is inserted at the end of the dest block, the dest
556  /// block must have no successors. Similarly, if the source block is inserted
557  /// somewhere in the middle (or beginning) of the dest block, the source block
558  /// must have no successors. Otherwise, the resulting IR would have
559  /// unreachable operations.
560  virtual void inlineBlockBefore(Block *source, Block *dest,
561  Block::iterator before,
562  ValueRange argValues = std::nullopt);
563 
564  /// Inline the operations of block 'source' before the operation 'op'. The
565  /// source block will be deleted and must have no uses. 'argValues' is used to
566  /// replace the block arguments of 'source'.
567  ///
568  /// The source block must have no successors. Otherwise, the resulting IR
569  /// would have unreachable operations.
570  void inlineBlockBefore(Block *source, Operation *op,
571  ValueRange argValues = std::nullopt);
572 
573  /// Inline the operations of block 'source' into the end of block 'dest'. The
574  /// source block will be deleted and must have no uses. 'argValues' is used to
575  /// replace the block arguments of 'source'.
576  ///
577  /// The dest block must have no successors. Otherwise, the resulting IR would
578  /// have unreachable operations.
579  void mergeBlocks(Block *source, Block *dest,
580  ValueRange argValues = std::nullopt);
581 
582  /// Split the operations starting at "before" (inclusive) out of the given
583  /// block into a new block, and return it.
584  virtual Block *splitBlock(Block *block, Block::iterator before);
585 
586  /// This method is used to notify the rewriter that an in-place operation
587  /// modification is about to happen. A call to this function *must* be
588  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
589  /// This is a minor efficiency win (it avoids creating a new operation and
590  /// removing the old one) but also often allows simpler code in the client.
591  virtual void startRootUpdate(Operation *op) {}
592 
593  /// This method is used to signal the end of a root update on the given
594  /// operation. This can only be called on operations that were provided to a
595  /// call to `startRootUpdate`.
596  virtual void finalizeRootUpdate(Operation *op);
597 
598  /// This method cancels a pending root update. This can only be called on
599  /// operations that were provided to a call to `startRootUpdate`.
600  virtual void cancelRootUpdate(Operation *op) {}
601 
602  /// This method is a utility wrapper around a root update of an operation. It
603  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
604  /// callable.
605  template <typename CallableT>
606  void updateRootInPlace(Operation *root, CallableT &&callable) {
607  startRootUpdate(root);
608  callable();
609  finalizeRootUpdate(root);
610  }
611 
612  /// Find uses of `from` and replace them with `to`. It also marks every
613  /// modified uses and notifies the rewriter that an in-place operation
614  /// modification is about to happen.
615  void replaceAllUsesWith(Value from, Value to) {
616  return replaceAllUsesWith(from.getImpl(), to);
617  }
618  template <typename OperandType, typename ValueT>
620  for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
621  Operation *op = operand.getOwner();
622  updateRootInPlace(op, [&]() { operand.set(to); });
623  }
624  }
626  assert(from.size() == to.size() && "incorrect number of replacements");
627  for (auto it : llvm::zip(from, to))
628  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
629  }
630 
631  /// Find uses of `from` and replace them with `to` if the `functor` returns
632  /// true. It also marks every modified use and notifies the rewriter that an
633  /// in-place operation modification is about to happen.
634  void replaceUsesWithIf(Value from, Value to,
635  function_ref<bool(OpOperand &)> functor);
637  function_ref<bool(OpOperand &)> functor) {
638  assert(from.size() == to.size() && "incorrect number of replacements");
639  for (auto it : llvm::zip(from, to))
640  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
641  }
642 
643  /// Find uses of `from` and replace them with `to` except if the user is
644  /// `exceptedUser`. It also marks every modified uses and notifies the
645  /// rewriter that an in-place operation modification is about to happen.
646  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
647  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
648  Operation *user = use.getOwner();
649  return user != exceptedUser;
650  });
651  }
652 
653  /// Used to notify the rewriter that the IR failed to be rewritten because of
654  /// a match failure, and provide a callback to populate a diagnostic with the
655  /// reason why the failure occurred. This method allows for derived rewriters
656  /// to optionally hook into the reason why a rewrite failed, and display it to
657  /// users.
658  template <typename CallbackT>
659  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
660  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
661 #ifndef NDEBUG
662  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
663  return rewriteListener->notifyMatchFailure(
664  loc, function_ref<void(Diagnostic &)>(reasonCallback));
665  return failure();
666 #else
667  return failure();
668 #endif
669  }
670  template <typename CallbackT>
671  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
672  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
673  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
674  return rewriteListener->notifyMatchFailure(
675  op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
676  return failure();
677  }
678  template <typename ArgT>
679  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
680  return notifyMatchFailure(std::forward<ArgT>(arg),
681  [&](Diagnostic &diag) { diag << msg; });
682  }
683  template <typename ArgT>
684  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
685  return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
686  }
687 
688 protected:
689  /// Initialize the builder.
690  explicit RewriterBase(MLIRContext *ctx,
691  OpBuilder::Listener *listener = nullptr)
692  : OpBuilder(ctx, listener) {}
693  explicit RewriterBase(const OpBuilder &otherBuilder)
694  : OpBuilder(otherBuilder) {}
695  virtual ~RewriterBase();
696 
697 private:
698  void operator=(const RewriterBase &) = delete;
699  RewriterBase(const RewriterBase &) = delete;
700 };
701 
702 //===----------------------------------------------------------------------===//
703 // IRRewriter
704 //===----------------------------------------------------------------------===//
705 
706 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
707 /// providing a way to keep track of the mutations made to the IR. This class
708 /// should only be used in situations where another `RewriterBase` instance,
709 /// such as a `PatternRewriter`, is not available.
710 class IRRewriter : public RewriterBase {
711 public:
713  : RewriterBase(ctx, listener) {}
714  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
715 };
716 
717 //===----------------------------------------------------------------------===//
718 // PatternRewriter
719 //===----------------------------------------------------------------------===//
720 
721 /// A special type of `RewriterBase` that coordinates the application of a
722 /// rewrite pattern on the current IR being matched, providing a way to keep
723 /// track of any mutations made. This class should be used to perform all
724 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
725 /// be tracking various state that would be invalidated when a mutation takes
726 /// place.
728 public:
730 
731  /// A hook used to indicate if the pattern rewriter can recover from failure
732  /// during the rewrite stage of a pattern. For example, if the pattern
733  /// rewriter supports rollback, it may progress smoothly even if IR was
734  /// changed during the rewrite.
735  virtual bool canRecoverFromRewriteFailure() const { return false; }
736 };
737 
738 //===----------------------------------------------------------------------===//
739 // PDL Patterns
740 //===----------------------------------------------------------------------===//
741 
742 //===----------------------------------------------------------------------===//
743 // PDLValue
744 
745 /// Storage type of byte-code interpreter values. These are passed to constraint
746 /// functions as arguments.
747 class PDLValue {
748 public:
749  /// The underlying kind of a PDL value.
751 
752  /// Construct a new PDL value.
753  PDLValue(const PDLValue &other) = default;
754  PDLValue(std::nullptr_t = nullptr) {}
756  : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
757  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
758  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
759  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
761  : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
762  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
763 
764  /// Returns true if the type of the held value is `T`.
765  template <typename T>
766  bool isa() const {
767  assert(value && "isa<> used on a null value");
768  return kind == getKindOf<T>();
769  }
770 
771  /// Attempt to dynamically cast this value to type `T`, returns null if this
772  /// value is not an instance of `T`.
773  template <typename T,
774  typename ResultT = std::conditional_t<
775  std::is_convertible<T, bool>::value, T, std::optional<T>>>
776  ResultT dyn_cast() const {
777  return isa<T>() ? castImpl<T>() : ResultT();
778  }
779 
780  /// Cast this value to type `T`, asserts if this value is not an instance of
781  /// `T`.
782  template <typename T>
783  T cast() const {
784  assert(isa<T>() && "expected value to be of type `T`");
785  return castImpl<T>();
786  }
787 
788  /// Get an opaque pointer to the value.
789  const void *getAsOpaquePointer() const { return value; }
790 
791  /// Return if this value is null or not.
792  explicit operator bool() const { return value; }
793 
794  /// Return the kind of this value.
795  Kind getKind() const { return kind; }
796 
797  /// Print this value to the provided output stream.
798  void print(raw_ostream &os) const;
799 
800  /// Print the specified value kind to an output stream.
801  static void print(raw_ostream &os, Kind kind);
802 
803 private:
/// Compile-time search for a type inside a type list:
/// `index_of_t<T, Ts...>::value` is the zero-based position of the first
/// occurrence of `T` within `Ts...`.
template <typename...>
struct index_of_t;
/// Hit: the list starts with `T`, so the position is 0.
template <typename T, typename... Rest>
struct index_of_t<T, T, Rest...> : std::integral_constant<size_t, 0> {};
/// Miss: skip the leading `Head` and add one to the position found in the
/// remainder of the list.
template <typename T, typename Head, typename... Rest>
struct index_of_t<T, Head, Rest...>
    : std::integral_constant<size_t, index_of_t<T, Rest...>::value + 1> {};
812 
813  /// Return the kind used for the given T.
814  template <typename T>
815  static Kind getKindOf() {
816  return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
817  TypeRange, Value, ValueRange>::value);
818  }
819 
820  /// The internal implementation of `cast`, that returns the underlying value
821  /// as the given type `T`.
822  template <typename T>
823  std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
824  castImpl() const {
825  return T::getFromOpaquePointer(value);
826  }
827  template <typename T>
828  std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
829  castImpl() const {
830  return *reinterpret_cast<T *>(const_cast<void *>(value));
831  }
832  template <typename T>
833  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
834  return reinterpret_cast<T>(const_cast<void *>(value));
835  }
836 
837  /// The internal opaque representation of a PDLValue.
838  const void *value{nullptr};
839  /// The kind of the opaque value.
840  Kind kind{Kind::Attribute};
841 };
842 
843 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
844  value.print(os);
845  return os;
846 }
847 
848 inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
849  PDLValue::print(os, kind);
850  return os;
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // PDLResultList
855 
856 /// The class represents a list of PDL results, returned by a native rewrite
857 /// method. It provides the mechanism with which to pass PDLValues back to the
858 /// PDL bytecode.
860 public:
861  /// Push a new Attribute value onto the result list.
862  void push_back(Attribute value) { results.push_back(value); }
863 
864  /// Push a new Operation onto the result list.
865  void push_back(Operation *value) { results.push_back(value); }
866 
867  /// Push a new Type onto the result list.
868  void push_back(Type value) { results.push_back(value); }
869 
870  /// Push a new TypeRange onto the result list.
871  void push_back(TypeRange value) {
872  // The lifetime of a TypeRange can't be guaranteed, so we'll need to
873  // allocate a storage for it.
874  llvm::OwningArrayRef<Type> storage(value.size());
875  llvm::copy(value, storage.begin());
876  allocatedTypeRanges.emplace_back(std::move(storage));
877  typeRanges.push_back(allocatedTypeRanges.back());
878  results.push_back(&typeRanges.back());
879  }
881  typeRanges.push_back(value);
882  results.push_back(&typeRanges.back());
883  }
885  typeRanges.push_back(value);
886  results.push_back(&typeRanges.back());
887  }
888 
889  /// Push a new Value onto the result list.
890  void push_back(Value value) { results.push_back(value); }
891 
892  /// Push a new ValueRange onto the result list.
893  void push_back(ValueRange value) {
894  // The lifetime of a ValueRange can't be guaranteed, so we'll need to
895  // allocate a storage for it.
896  llvm::OwningArrayRef<Value> storage(value.size());
897  llvm::copy(value, storage.begin());
898  allocatedValueRanges.emplace_back(std::move(storage));
899  valueRanges.push_back(allocatedValueRanges.back());
900  results.push_back(&valueRanges.back());
901  }
902  void push_back(OperandRange value) {
903  valueRanges.push_back(value);
904  results.push_back(&valueRanges.back());
905  }
906  void push_back(ResultRange value) {
907  valueRanges.push_back(value);
908  results.push_back(&valueRanges.back());
909  }
910 
911 protected:
912  /// Create a new result list with the expected number of results.
913  PDLResultList(unsigned maxNumResults) {
914  // For now just reserve enough space for all of the results. We could do
915  // separate counts per range type, but it isn't really worth it unless there
916  // are a "large" number of results.
917  typeRanges.reserve(maxNumResults);
918  valueRanges.reserve(maxNumResults);
919  }
920 
921  /// The PDL results held by this list.
923  /// Memory used to store ranges held by the list.
926  /// Memory allocated to store ranges in the result list whose lifetime was
927  /// generated in the native function.
930 };
931 
932 //===----------------------------------------------------------------------===//
933 // PDLPatternConfig
934 
935 /// An individual configuration for a pattern, which can be accessed by native
936 /// functions via the PDLPatternConfigSet. This allows for injecting additional
937 /// configuration into PDL patterns that is specific to certain compilation
938 /// flows.
940 public:
941  virtual ~PDLPatternConfig() = default;
942 
943  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
944  /// pattern. These can be used to setup any specific state necessary for the
945  /// rewrite.
946  virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
947  virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
948 
949  /// Return the TypeID that represents this configuration.
950  TypeID getTypeID() const { return id; }
951 
952 protected:
953  PDLPatternConfig(TypeID id) : id(id) {}
954 
955 private:
956  TypeID id;
957 };
958 
959 /// This class provides a base class for users implementing a type of pattern
960 /// configuration.
961 template <typename T>
963 public:
964  /// Support LLVM style casting.
965  static bool classof(const PDLPatternConfig *config) {
966  return config->getTypeID() == getConfigID();
967  }
968 
969  /// Return the type id used for this configuration.
970  static TypeID getConfigID() { return TypeID::get<T>(); }
971 
972 protected:
974 };
975 
976 /// This class contains a set of configurations for a specific pattern.
977 /// Configurations are uniqued by TypeID, meaning that only one configuration of
978 /// each type is allowed.
980 public:
981  PDLPatternConfigSet() = default;
982 
983  /// Construct a set with the given configurations.
984  template <typename... ConfigsT>
985  PDLPatternConfigSet(ConfigsT &&...configs) {
986  (addConfig(std::forward<ConfigsT>(configs)), ...);
987  }
988 
989  /// Get the configuration defined by the given type. Asserts that the
990  /// configuration of the provided type exists.
991  template <typename T>
992  const T &get() const {
993  const T *config = tryGet<T>();
994  assert(config && "configuration not found");
995  return *config;
996  }
997 
998  /// Get the configuration defined by the given type, returns nullptr if the
999  /// configuration does not exist.
1000  template <typename T>
1001  const T *tryGet() const {
1002  for (const auto &configIt : configs)
1003  if (const T *config = dyn_cast<T>(configIt.get()))
1004  return config;
1005  return nullptr;
1006  }
1007 
1008  /// Notify the configurations within this set at the beginning or end of a
1009  /// rewrite of a matched pattern.
1011  for (const auto &config : configs)
1012  config->notifyRewriteBegin(rewriter);
1013  }
1015  for (const auto &config : configs)
1016  config->notifyRewriteEnd(rewriter);
1017  }
1018 
1019 protected:
1020  /// Add a configuration to the set.
1021  template <typename T>
1022  void addConfig(T &&config) {
1023  assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
1024  configs.emplace_back(
1025  std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
1026  }
1027 
1028  /// The set of configurations for this pattern. This uses a vector instead of
1029  /// a map with the expectation that the number of configurations per set is
1030  /// small (<= 1).
1032 };
1033 
1034 //===----------------------------------------------------------------------===//
1035 // PDLPatternModule
1036 
1037 /// A generic PDL pattern constraint function. This function applies a
1038 /// constraint to a given set of opaque PDLValue entities. Returns success if
1039 /// the constraint successfully held, failure otherwise.
1042 /// A native PDL rewrite function. This function performs a rewrite on the
1043 /// given set of values. Any results from this rewrite that should be passed
1044 /// back to PDL should be added to the provided result list. This method is only
1045 /// invoked when the corresponding match was successful. Returns failure if an
1046 /// invariant of the rewrite was broken (certain rewriters may recover from
1047 /// partial pattern application).
1048 using PDLRewriteFunction = std::function<LogicalResult(
1050 
1051 namespace detail {
1052 namespace pdl_function_builder {
/// A variable template that is `false` for every set of template arguments.
/// Because its value formally depends on the arguments, it can be used in a
/// `static_assert` that must only fire when a particular template is actually
/// instantiated. For example, a templated function that must never be called
/// can be written as:
///
///   template <typename T>
///   void foo() {
///     static_assert(always_false<T>, "This function should never be called");
///   }
///
template <typename... Ts>
constexpr bool always_false = false;
1065 
1066 //===----------------------------------------------------------------------===//
1067 // PDL Function Builder: Type Processing
1068 //===----------------------------------------------------------------------===//
1069 
1070 /// This struct provides a convenient way to determine how to process a given
1071 /// type as either a PDL parameter, or a result value. This allows for
1072 /// supporting complex types in constraint and rewrite functions, without
1073 /// requiring the user to hand-write the necessary glue code themselves.
1074 /// Specializations of this class should implement the following methods to
1075 /// enable support as a PDL argument or result type:
1076 ///
1077 /// static LogicalResult verifyAsArg(
1078 /// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
1079 /// size_t argIdx);
1080 ///
1081 /// * This method verifies that the given PDLValue is valid for use as a
1082 /// value of `T`.
1083 ///
1084 /// static T processAsArg(PDLValue pdlValue);
1085 ///
1086 /// * This method processes the given PDLValue as a value of `T`.
1087 ///
1088 /// static void processAsResult(PatternRewriter &, PDLResultList &results,
1089 /// const T &value);
1090 ///
1091 /// * This method processes the given value of `T` as the result of a
1092 /// function invocation. The method should package the value into an
1093 /// appropriate form and append it to the given result list.
1094 ///
1095 /// If the type `T` is based on a higher order value, consider using
1096 /// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
1097 /// the implementation.
1098 ///
1099 template <typename T, typename Enable = void>
1101 
1102 /// This struct provides a simplified model for processing types that are based
1103 /// on another type, e.g. APInt is based on the handling for IntegerAttr. This
1104 /// allows for building the necessary processing functions on top of the base
1105 /// value instead of a PDLValue. Derived users should implement the following
1106 /// (which subsume the ProcessPDLValue variants):
1107 ///
1108 /// static LogicalResult verifyAsArg(
1109 /// function_ref<LogicalResult(const Twine &)> errorFn,
1110 /// const BaseT &baseValue, size_t argIdx);
1111 ///
1112 /// * This method verifies that the given base value is valid for use as a
1113 /// value of `T`.
1114 ///
1115 /// static T processAsArg(BaseT baseValue);
1116 ///
1117 /// * This method processes the given base value as a value of `T`.
1118 ///
1119 template <typename T, typename BaseT>
1121  static LogicalResult
1122  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1123  PDLValue pdlValue, size_t argIdx) {
1124  // Verify the base class before continuing.
1125  if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
1126  return failure();
1128  errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
1129  }
1130  static T processAsArg(PDLValue pdlValue) {
1133  }
1134 
1135  /// Explicitly add the expected parent API to ensure the parent class
1136  /// implements the necessary API (and doesn't implicitly inherit it from
1137  /// somewhere else).
1138  static LogicalResult
1139  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
1140  size_t argIdx) {
1141  return success();
1142  }
1143  static T processAsArg(BaseT baseValue);
1144 };
1145 
1146 /// This struct provides a simplified model for processing types that have
1147 /// "builtin" PDLValue support:
1148 /// * Attribute, Operation *, Type, TypeRange, Value, ValueRange
1149 template <typename T>
1151  static LogicalResult
1152  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1153  PDLValue pdlValue, size_t argIdx) {
1154  if (pdlValue)
1155  return success();
1156  return errorFn("expected a non-null value for argument " + Twine(argIdx) +
1157  " of type: " + llvm::getTypeName<T>());
1158  }
1159 
1160  static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
1162  T value) {
1163  results.push_back(value);
1164  }
1165 };
1166 
1167 /// This struct provides a simplified model for processing types that inherit
1168 /// from builtin PDLValue types. For example, derived attributes like
1169 /// IntegerAttr, derived types like IntegerType, derived operations like
1170 /// ModuleOp, Interfaces, etc.
1171 template <typename T, typename BaseT>
1173  static LogicalResult
1174  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1175  BaseT baseValue, size_t argIdx) {
1176  return TypeSwitch<BaseT, LogicalResult>(baseValue)
1177  .Case([&](T) { return success(); })
1178  .Default([&](BaseT) {
1179  return errorFn("expected argument " + Twine(argIdx) +
1180  " to be of type: " + llvm::getTypeName<T>());
1181  });
1182  }
1184 
1185  static T processAsArg(BaseT baseValue) {
1186  return baseValue.template cast<T>();
1187  }
1189 
1191  T value) {
1192  results.push_back(value);
1193  }
1194 };
1195 
1196 //===----------------------------------------------------------------------===//
1197 // Attribute
1198 
1199 template <>
1200 struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
1201 template <typename T>
1203  std::enable_if_t<std::is_base_of<Attribute, T>::value>>
1204  : public ProcessDerivedPDLValue<T, Attribute> {};
1205 
1206 /// Handling for various Attribute value types.
1207 template <>
1208 struct ProcessPDLValue<StringRef>
1209  : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
1210  static StringRef processAsArg(StringAttr value) { return value.getValue(); }
1212 
1213  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
1214  StringRef value) {
1215  results.push_back(rewriter.getStringAttr(value));
1216  }
1217 };
1218 template <>
1219 struct ProcessPDLValue<std::string>
1220  : public ProcessPDLValueBasedOn<std::string, StringAttr> {
1221  template <typename T>
1222  static std::string processAsArg(T value) {
1223  static_assert(always_false<T>,
1224  "`std::string` arguments require a string copy, use "
1225  "`StringRef` for string-like arguments instead");
1226  return {};
1227  }
1228  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
1229  StringRef value) {
1230  results.push_back(rewriter.getStringAttr(value));
1231  }
1232 };
1233 
1234 //===----------------------------------------------------------------------===//
1235 // Operation
1236 
1237 template <>
1240 template <typename T>
1241 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
1242  : public ProcessDerivedPDLValue<T, Operation *> {
1243  static T processAsArg(Operation *value) { return cast<T>(value); }
1244 };
1245 
1246 //===----------------------------------------------------------------------===//
1247 // Type
1248 
1249 template <>
1250 struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
1251 template <typename T>
1252 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
1253  : public ProcessDerivedPDLValue<T, Type> {};
1254 
1255 //===----------------------------------------------------------------------===//
1256 // TypeRange
1257 
1258 template <>
1259 struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
1260 template <>
1264  results.push_back(types);
1265  }
1266 };
1267 template <>
1271  results.push_back(types);
1272  }
1273 };
1274 template <unsigned N>
1277  SmallVector<Type, N> values) {
1278  results.push_back(TypeRange(values));
1279  }
1280 };
1281 
1282 //===----------------------------------------------------------------------===//
1283 // Value
1284 
1285 template <>
1286 struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
1287 
1288 //===----------------------------------------------------------------------===//
1289 // ValueRange
1290 
1291 template <>
1292 struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
1293 };
1294 template <>
1297  OperandRange values) {
1298  results.push_back(values);
1299  }
1300 };
1301 template <>
1304  ResultRange values) {
1305  results.push_back(values);
1306  }
1307 };
1308 template <unsigned N>
1311  SmallVector<Value, N> values) {
1312  results.push_back(ValueRange(values));
1313  }
1314 };
1315 
1316 //===----------------------------------------------------------------------===//
1317 // PDL Function Builder: Argument Handling
1318 //===----------------------------------------------------------------------===//
1319 
1320 /// Validate the given PDLValues match the constraints defined by the argument
1321 /// types of the given function. In the case of failure, a match failure
1322 /// diagnostic is emitted.
1323 /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
1324 /// does not currently preserve Constraint application ordering.
1325 template <typename PDLFnT, std::size_t... I>
1327  std::index_sequence<I...>) {
1328  using FnTraitsT = llvm::function_traits<PDLFnT>;
1329 
1330  auto errorFn = [&](const Twine &msg) {
1331  return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
1332  };
1333  return success(
1334  (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1335  verifyAsArg(errorFn, values[I], I)) &&
1336  ...));
1337 }
1338 
1339 /// Assert that the given PDLValues match the constraints defined by the
1340 /// arguments of the given function. In the case of failure, a fatal error
1341 /// is emitted.
1342 template <typename PDLFnT, std::size_t... I>
1344  std::index_sequence<I...>) {
1345  // We only want to do verification in debug builds, same as with `assert`.
1346 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1347  using FnTraitsT = llvm::function_traits<PDLFnT>;
1348  auto errorFn = [&](const Twine &msg) -> LogicalResult {
1349  llvm::report_fatal_error(msg);
1350  };
1351  (void)errorFn;
1352  assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1353  verifyAsArg(errorFn, values[I], I)) &&
1354  ...));
1355 #endif
1356  (void)values;
1357 }
1358 
1359 //===----------------------------------------------------------------------===//
1360 // PDL Function Builder: Results Handling
1361 //===----------------------------------------------------------------------===//
1362 
1363 /// Store a single result within the result list.
1364 template <typename T>
1366  PDLResultList &results, T &&value) {
1367  ProcessPDLValue<T>::processAsResult(rewriter, results,
1368  std::forward<T>(value));
1369  return success();
1370 }
1371 
1372 /// Store a std::pair<> as individual results within the result list.
1373 template <typename T1, typename T2>
1375  PDLResultList &results,
1376  std::pair<T1, T2> &&pair) {
1377  if (failed(processResults(rewriter, results, std::move(pair.first))) ||
1378  failed(processResults(rewriter, results, std::move(pair.second))))
1379  return failure();
1380  return success();
1381 }
1382 
1383 /// Store a std::tuple<> as individual results within the result list.
1384 template <typename... Ts>
1386  PDLResultList &results,
1387  std::tuple<Ts...> &&tuple) {
1388  auto applyFn = [&](auto &&...args) {
1389  return (succeeded(processResults(rewriter, results, std::move(args))) &&
1390  ...);
1391  };
1392  return success(std::apply(applyFn, std::move(tuple)));
1393 }
1394 
1395 /// Handle LogicalResult propagation.
1397  PDLResultList &results,
1398  LogicalResult &&result) {
1399  return result;
1400 }
1401 template <typename T>
1403  PDLResultList &results,
1404  FailureOr<T> &&result) {
1405  if (failed(result))
1406  return failure();
1407  return processResults(rewriter, results, std::move(*result));
1408 }
1409 
1410 //===----------------------------------------------------------------------===//
1411 // PDL Constraint Builder
1412 //===----------------------------------------------------------------------===//
1413 
1414 /// Process the arguments of a native constraint and invoke it.
1415 template <typename PDLFnT, std::size_t... I,
1416  typename FnTraitsT = llvm::function_traits<PDLFnT>>
1417 typename FnTraitsT::result_t
1419  ArrayRef<PDLValue> values,
1420  std::index_sequence<I...>) {
1421  return fn(
1422  rewriter,
1423  (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1424  values[I]))...);
1425 }
1426 
1427 /// Build a constraint function from the given function `ConstraintFnT`. This
1428 /// allows for enabling the user to define simpler, more direct constraint
1429 /// functions without needing to handle the low-level PDL goop.
1430 ///
1431 /// If the constraint function is already in the correct form, we just forward
1432 /// it directly.
1433 template <typename ConstraintFnT>
1434 std::enable_if_t<
1435  std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1437 buildConstraintFn(ConstraintFnT &&constraintFn) {
1438  return std::forward<ConstraintFnT>(constraintFn);
1439 }
1440 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1441 /// we desire.
1442 template <typename ConstraintFnT>
1443 std::enable_if_t<
1444  !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1446 buildConstraintFn(ConstraintFnT &&constraintFn) {
1447  return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1448  PatternRewriter &rewriter,
1449  ArrayRef<PDLValue> values) -> LogicalResult {
1450  auto argIndices = std::make_index_sequence<
1451  llvm::function_traits<ConstraintFnT>::num_args - 1>();
1452  if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1453  return failure();
1454  return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
1455  argIndices);
1456  };
1457 }
1458 
1459 //===----------------------------------------------------------------------===//
1460 // PDL Rewrite Builder
1461 //===----------------------------------------------------------------------===//
1462 
1463 /// Process the arguments of a native rewrite and invoke it.
1464 /// This overload handles the case of no return values.
1465 template <typename PDLFnT, std::size_t... I,
1466  typename FnTraitsT = llvm::function_traits<PDLFnT>>
1467 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
1468  LogicalResult>
1471  std::index_sequence<I...>) {
1472  fn(rewriter,
1473  (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1474  values[I]))...);
1475  return success();
1476 }
1477 /// This overload handles the case of return values, which need to be packaged
1478 /// into the result list.
1479 template <typename PDLFnT, std::size_t... I,
1480  typename FnTraitsT = llvm::function_traits<PDLFnT>>
1481 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
1482  LogicalResult>
1484  PDLResultList &results, ArrayRef<PDLValue> values,
1485  std::index_sequence<I...>) {
1486  return processResults(
1487  rewriter, results,
1488  fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1489  processAsArg(values[I]))...));
1490  (void)values;
1491 }
1492 
1493 /// Build a rewrite function from the given function `RewriteFnT`. This
1494 /// allows for enabling the user to define simpler, more direct rewrite
1495 /// functions without needing to handle the low-level PDL goop.
1496 ///
1497 /// If the rewrite function is already in the correct form, we just forward
1498 /// it directly.
1499 template <typename RewriteFnT>
1500 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1502 buildRewriteFn(RewriteFnT &&rewriteFn) {
1503  return std::forward<RewriteFnT>(rewriteFn);
1504 }
1505 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1506 /// we desire.
1507 template <typename RewriteFnT>
1508 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1510 buildRewriteFn(RewriteFnT &&rewriteFn) {
1511  return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1512  PatternRewriter &rewriter, PDLResultList &results,
1513  ArrayRef<PDLValue> values) {
1514  auto argIndices =
1515  std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1516  1>();
1517  assertArgs<RewriteFnT>(rewriter, values, argIndices);
1518  return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
1519  argIndices);
1520  };
1521 }
1522 
1523 } // namespace pdl_function_builder
1524 } // namespace detail
1525 
1526 //===----------------------------------------------------------------------===//
1527 // PDLPatternModule
1528 
1529 /// This class contains all of the necessary data for a set of PDL patterns, or
1530 /// pattern rewrites specified in the form of the PDL dialect. The PDL module
1531 /// contained by this pattern may contain any number of `pdl.pattern`
1532 /// operations.
1534 public:
1535  PDLPatternModule() = default;
1536 
1537  /// Construct a PDL pattern with the given module and configurations.
1539  : pdlModule(std::move(module)) {}
1540  template <typename... ConfigsT>
1541  PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
1542  : PDLPatternModule(std::move(module)) {
1543  auto configSet = std::make_unique<PDLPatternConfigSet>(
1544  std::forward<ConfigsT>(patternConfigs)...);
1545  attachConfigToPatterns(*pdlModule, *configSet);
1546  configs.emplace_back(std::move(configSet));
1547  }
1548 
1549  /// Merge the state in `other` into this pattern module.
1550  void mergeIn(PDLPatternModule &&other);
1551 
1552  /// Return the internal PDL module of this pattern.
1553  ModuleOp getModule() { return pdlModule.get(); }
1554 
1555  //===--------------------------------------------------------------------===//
1556  // Function Registry
1557 
1558  /// Register a constraint function with PDL. A constraint function may be
1559  /// specified in one of two ways:
1560  ///
1561  /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
1562  ///
1563  /// In this overload the arguments of the constraint function are passed via
1564  /// the low-level PDLValue form.
1565  ///
1566  /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
1567  ///
1568  /// In this form the arguments of the constraint function are passed via the
1569  /// expected high level C++ type. In this form, the framework will
1570  /// automatically unwrap PDLValues and convert them to the expected ValueTs.
1571  /// For example, if the constraint function accepts a `Operation *`, the
1572  /// framework will automatically cast the input PDLValue. In the case of a
1573  /// `StringRef`, the framework will automatically unwrap the argument as a
1574  /// StringAttr and pass the underlying string value. To see the full list of
1575  /// supported types, or to see how to add handling for custom types, view
1576  /// the definition of `ProcessPDLValue` above.
1577  void registerConstraintFunction(StringRef name,
1578  PDLConstraintFunction constraintFn);
1579  template <typename ConstraintFnT>
1580  void registerConstraintFunction(StringRef name,
1581  ConstraintFnT &&constraintFn) {
1584  std::forward<ConstraintFnT>(constraintFn)));
1585  }
1586 
1587  /// Register a rewrite function with PDL. A rewrite function may be specified
1588  /// in one of two ways:
1589  ///
1590  /// * `LogicalResult (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
1591  ///
1592  /// In this overload the arguments of the rewrite function are passed via
1593  /// the low-level PDLValue form, and the results are manually appended to
1594  /// the given result list.
1595  ///
1596  /// * `ResultT (PatternRewriter &, ValueTs... values)`
1597  ///
1598  /// In this form the arguments and result of the rewrite function are passed
1599  /// via the expected high level C++ type. In this form, the framework will
1600  /// automatically unwrap the PDLValues arguments and convert them to the
1601  /// expected ValueTs. It will also automatically handle the processing and
1602  /// packaging of the result value to the result list. For example, if the
1603  /// rewrite function takes an `Operation *`, the framework will automatically
1604  /// cast the input PDLValue. In the case of a `StringRef`, the framework
1605  /// will automatically unwrap the argument as a StringAttr and pass the
1606  /// underlying string value. In the reverse case, if the rewrite returns a
1607  /// StringRef or std::string, it will automatically package this as a
1608  /// StringAttr and append it to the result list. To see the full list of
1609  /// supported types, or to see how to add handling for custom types, view
1610  /// the definition of `ProcessPDLValue` above.
1611  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
1612  template <typename RewriteFnT>
1613  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
1615  std::forward<RewriteFnT>(rewriteFn)));
1616  }
1617 
1618  /// Return the set of the registered constraint functions.
1619  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
1620  return constraintFunctions;
1621  }
1622  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
1623  return constraintFunctions;
1624  }
1625  /// Return the set of the registered rewrite functions.
1626  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
1627  return rewriteFunctions;
1628  }
1629  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
1630  return rewriteFunctions;
1631  }
1632 
1633  /// Return the set of the registered pattern configs.
1635  return std::move(configs);
1636  }
1638  return std::move(configMap);
1639  }
1640 
1641  /// Clear out the patterns and functions within this module.
1642  void clear() {
1643  pdlModule = nullptr;
1644  constraintFunctions.clear();
1645  rewriteFunctions.clear();
1646  }
1647 
1648 private:
1649  /// Attach the given pattern config set to the patterns defined within the
1650  /// given module.
1651  void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
1652 
1653  /// The module containing the `pdl.pattern` operations.
1654  OwningOpRef<ModuleOp> pdlModule;
1655 
1656  /// The set of configuration sets referenced by patterns within `pdlModule`.
1659 
1660  /// The external functions referenced from within the PDL module.
1661  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
1662  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1663 };
1664 
1665 //===----------------------------------------------------------------------===//
1666 // RewritePatternSet
1667 //===----------------------------------------------------------------------===//
1668 
1670  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1671 
1672 public:
1673  RewritePatternSet(MLIRContext *context) : context(context) {}
1674 
1675  /// Construct a RewritePatternSet populated with the given pattern.
1677  std::unique_ptr<RewritePattern> pattern)
1678  : context(context) {
1679  nativePatterns.emplace_back(std::move(pattern));
1680  }
1682  : context(pattern.getModule()->getContext()),
1683  pdlPatterns(std::move(pattern)) {}
1684 
1685  MLIRContext *getContext() const { return context; }
1686 
1687  /// Return the native patterns held in this list.
1688  NativePatternListT &getNativePatterns() { return nativePatterns; }
1689 
1690  /// Return the PDL patterns held in this list.
1691  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
1692 
1693  /// Clear out all of the held patterns in this list.
1694  void clear() {
1695  nativePatterns.clear();
1696  pdlPatterns.clear();
1697  }
1698 
1699  //===--------------------------------------------------------------------===//
1700  // 'add' methods for adding patterns to the set.
1701  //===--------------------------------------------------------------------===//
1702 
1703  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1704  /// the given arguments. Return a reference to `this` for chaining insertions.
1705  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1706  template <typename... Ts, typename ConstructorArg,
1707  typename... ConstructorArgs,
1708  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1709  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
1710  // The following expands a call to emplace_back for each of the pattern
1711  // types 'Ts'.
1712  (addImpl<Ts>(/*debugLabels=*/std::nullopt,
1713  std::forward<ConstructorArg>(arg),
1714  std::forward<ConstructorArgs>(args)...),
1715  ...);
1716  return *this;
1717  }
1718  /// An overload of the above `add` method that allows for attaching a set
1719  /// of debug labels to the attached patterns. This is useful for labeling
1720  /// groups of patterns that may be shared between multiple different
1721  /// passes/users.
1722  template <typename... Ts, typename ConstructorArg,
1723  typename... ConstructorArgs,
1724  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1726  ConstructorArg &&arg,
1727  ConstructorArgs &&...args) {
1728  // The following expands a call to emplace_back for each of the pattern
1729  // types 'Ts'.
1730  (addImpl<Ts>(debugLabels, arg, args...), ...);
1731  return *this;
1732  }
1733 
1734  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1735  /// `this` for chaining insertions.
1736  template <typename... Ts>
1738  (addImpl<Ts>(), ...);
1739  return *this;
1740  }
1741 
1742  /// Add the given native pattern to the pattern list. Return a reference to
1743  /// `this` for chaining insertions.
1744  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
1745  nativePatterns.emplace_back(std::move(pattern));
1746  return *this;
1747  }
1748 
1749  /// Add the given PDL pattern to the pattern list. Return a reference to
1750  /// `this` for chaining insertions.
1752  pdlPatterns.mergeIn(std::move(pattern));
1753  return *this;
1754  }
1755 
1756  // Add a matchAndRewrite style pattern represented as a C function pointer.
1757  template <typename OpType>
1759  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1760  PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
1761  struct FnPattern final : public OpRewritePattern<OpType> {
1762  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1763  MLIRContext *context, PatternBenefit benefit,
1764  ArrayRef<StringRef> generatedNames)
1765  : OpRewritePattern<OpType>(context, benefit, generatedNames),
1766  implFn(implFn) {}
1767 
1768  LogicalResult matchAndRewrite(OpType op,
1769  PatternRewriter &rewriter) const override {
1770  return implFn(op, rewriter);
1771  }
1772 
1773  private:
1774  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1775  };
1776  add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
1777  generatedNames));
1778  return *this;
1779  }
1780 
1781  //===--------------------------------------------------------------------===//
1782  // Pattern Insertion
1783  //===--------------------------------------------------------------------===//
1784 
1785  // TODO: These are soft deprecated in favor of the 'add' methods above.
1786 
1787  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1788  /// the given arguments. Return a reference to `this` for chaining insertions.
1789  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1790  template <typename... Ts, typename ConstructorArg,
1791  typename... ConstructorArgs,
1792  typename = std::enable_if_t<sizeof...(Ts) != 0>>
1793  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
1794  // The following expands a call to emplace_back for each of the pattern
1795  // types 'Ts'.
1796  (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...);
1797  return *this;
1798  }
1799 
1800  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1801  /// `this` for chaining insertions.
1802  template <typename... Ts>
1804  (addImpl<Ts>(), ...);
1805  return *this;
1806  }
1807 
1808  /// Add the given native pattern to the pattern list. Return a reference to
1809  /// `this` for chaining insertions.
1810  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
1811  nativePatterns.emplace_back(std::move(pattern));
1812  return *this;
1813  }
1814 
1815  /// Add the given PDL pattern to the pattern list. Return a reference to
1816  /// `this` for chaining insertions.
1818  pdlPatterns.mergeIn(std::move(pattern));
1819  return *this;
1820  }
1821 
1822  // Add a matchAndRewrite style pattern represented as a C function pointer.
1823  template <typename OpType>
1825  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
1826  struct FnPattern final : public OpRewritePattern<OpType> {
1827  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1828  MLIRContext *context)
1829  : OpRewritePattern<OpType>(context), implFn(implFn) {
1830  this->setDebugName(llvm::getTypeName<FnPattern>());
1831  }
1832 
1833  LogicalResult matchAndRewrite(OpType op,
1834  PatternRewriter &rewriter) const override {
1835  return implFn(op, rewriter);
1836  }
1837 
1838  private:
1839  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1840  };
1841  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1842  return *this;
1843  }
1844 
1845 private:
1846  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
1847  /// chaining insertions.
1848  template <typename T, typename... Args>
1849  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
1850  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1851  std::unique_ptr<T> pattern =
1852  RewritePattern::create<T>(std::forward<Args>(args)...);
1853  pattern->addDebugLabels(debugLabels);
1854  nativePatterns.emplace_back(std::move(pattern));
1855  }
1856  template <typename T, typename... Args>
1857  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1858  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1859  // TODO: Add the provided labels to the PDL pattern when PDL supports
1860  // labels.
1861  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1862  }
1863 
1864  MLIRContext *const context;
1865  NativePatternListT nativePatterns;
1866  PDLPatternModule pdlPatterns;
1867 };
1868 
1869 } // namespace mlir
1870 
1871 #endif // MLIR_IR_PATTERNMATCH_H
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
Location getUnknownLoc()
Definition: Builders.cpp:27
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a single IR object that contains a use list.
Definition: UseDefLists.h:185
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: UseDefLists.h:243
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:712
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:714
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:572
This class represents an operand of an operation.
Definition: Value.h:261
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:383
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:385
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:50
This class provides a base class for users implementing a type of pattern configuration.
Definition: PatternMatch.h:962
static TypeID getConfigID()
Return the type id used for this configuration.
Definition: PatternMatch.h:970
static bool classof(const PDLPatternConfig *config)
Support LLVM style casting.
Definition: PatternMatch.h:965
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:979
const T & get() const
Get the configuration defined by the given type.
Definition: PatternMatch.h:992
PDLPatternConfigSet(ConfigsT &&...configs)
Construct a set with the given configurations.
Definition: PatternMatch.h:985
const T * tryGet() const
Get the configuration defined by the given type, returns nullptr if the configuration does not exist.
SmallVector< std::unique_ptr< PDLPatternConfig > > configs
The set of configurations for this pattern.
void addConfig(T &&config)
Add a configuration to the set.
void notifyRewriteBegin(PatternRewriter &rewriter)
Notify the configurations within this set at the beginning or end of a rewrite of a matched pattern.
void notifyRewriteEnd(PatternRewriter &rewriter)
An individual configuration for a pattern, which can be accessed by native functions via the PDLPatte...
Definition: PatternMatch.h:939
virtual ~PDLPatternConfig()=default
PDLPatternConfig(TypeID id)
Definition: PatternMatch.h:953
virtual void notifyRewriteEnd(PatternRewriter &rewriter)
Definition: PatternMatch.h:947
virtual void notifyRewriteBegin(PatternRewriter &rewriter)
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
Definition: PatternMatch.h:946
TypeID getTypeID() const
Return the TypeID that represents this configuration.
Definition: PatternMatch.h:950
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void clear()
Clear out the patterns and functions within this module.
const llvm::StringMap< PDLRewriteFunction > & getRewriteFunctions() const
Return the set of the registered rewrite functions.
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
PDLPatternModule(OwningOpRef< ModuleOp > module, ConfigsT &&...patternConfigs)
void registerConstraintFunction(StringRef name, ConstraintFnT &&constraintFn)
PDLPatternModule(OwningOpRef< ModuleOp > module)
Construct a PDL pattern with the given module and configurations.
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn)
ModuleOp getModule()
Return the internal PDL module of this pattern.
const llvm::StringMap< PDLConstraintFunction > & getConstraintFunctions() const
Return the set of the registered constraint functions.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
SmallVector< std::unique_ptr< PDLPatternConfigSet > > takeConfigs()
Return the set of the registered pattern configs.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
DenseMap< Operation *, PDLPatternConfigSet * > takeConfigMap()
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:859
void push_back(ValueTypeRange< OperandRange > value)
Definition: PatternMatch.h:880
void push_back(ResultRange value)
Definition: PatternMatch.h:906
void push_back(ValueRange value)
Push a new ValueRange onto the result list.
Definition: PatternMatch.h:893
PDLResultList(unsigned maxNumResults)
Create a new result list with the expected number of results.
Definition: PatternMatch.h:913
SmallVector< llvm::OwningArrayRef< Type > > allocatedTypeRanges
Memory allocated to store ranges in the result list whose lifetime was generated in the native functi...
Definition: PatternMatch.h:928
void push_back(ValueTypeRange< ResultRange > value)
Definition: PatternMatch.h:884
SmallVector< llvm::OwningArrayRef< Value > > allocatedValueRanges
Definition: PatternMatch.h:929
void push_back(Attribute value)
Push a new Attribute value onto the result list.
Definition: PatternMatch.h:862
SmallVector< TypeRange > typeRanges
Memory used to store ranges held by the list.
Definition: PatternMatch.h:924
SmallVector< PDLValue > results
The PDL results held by this list.
Definition: PatternMatch.h:922
void push_back(Type value)
Push a new Type onto the result list.
Definition: PatternMatch.h:868
void push_back(Operation *value)
Push a new Operation onto the result list.
Definition: PatternMatch.h:865
void push_back(OperandRange value)
Definition: PatternMatch.h:902
void push_back(Value value)
Push a new Value onto the result list.
Definition: PatternMatch.h:890
SmallVector< ValueRange > valueRanges
Definition: PatternMatch.h:925
void push_back(TypeRange value)
Push a new TypeRange onto the result list.
Definition: PatternMatch.h:871
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:747
PDLValue(std::nullptr_t=nullptr)
Definition: PatternMatch.h:754
PDLValue(Type value)
Definition: PatternMatch.h:758
const void * getAsOpaquePointer() const
Get an opaque pointer to the value.
Definition: PatternMatch.h:789
PDLValue(Attribute value)
Definition: PatternMatch.h:755
PDLValue(Operation *value)
Definition: PatternMatch.h:757
Kind getKind() const
Return the kind of this value.
Definition: PatternMatch.h:795
ResultT dyn_cast() const
Attempt to dynamically cast this value to type T, returns null if this value is not an instance of T.
Definition: PatternMatch.h:776
bool isa() const
Returns true if the type of the held value is T.
Definition: PatternMatch.h:766
void print(raw_ostream &os) const
Print this value to the provided output stream.
PDLValue(TypeRange *value)
Definition: PatternMatch.h:759
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:750
T cast() const
Cast this value to type T, asserts if this value is not an instance of T.
Definition: PatternMatch.h:783
PDLValue(ValueRange *value)
Definition: PatternMatch.h:762
PDLValue(const PDLValue &other)=default
Construct a new PDL value.
PDLValue(Value value)
Definition: PatternMatch.h:760
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:53
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:49
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:42
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:58
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:52
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:56
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:735
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:102
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:128
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:93
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:133
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:146
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:201
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:89
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:122
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:111
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:143
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:152
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:149
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:139
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:233
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:275
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:263
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Definition: PatternMatch.h:672
void replaceOpWithIf(Operation *op, ValueRange newValues, llvm::unique_function< bool(OpOperand &) const > functor)
Definition: PatternMatch.h:513
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Definition: PatternMatch.h:679
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:693
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(ValueRange from, ValueRange to)
Definition: PatternMatch.h:625
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:646
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Definition: PatternMatch.h:690
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
Definition: PatternMatch.h:600
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
Definition: PatternMatch.h:684
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:591
void replaceAllUsesWith(IRObjectWithUseList< OperandType > *from, ValueT &&to)
Definition: PatternMatch.h:619
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref< bool(OpOperand &)> functor)
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:132
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
detail::ValueImpl * getImpl() const
Definition: Value.h:241
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
FnTraitsT::result_t processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native constraint and invoke it.
void assertArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Assert that the given PDLValues match the constraints defined by the arguments of the given function.
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Validate the given PDLValues match the constraints defined by the argument types of the given functio...
std::enable_if_t< std::is_same< typename FnTraitsT::result_t, void >::value, LogicalResult > processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native rewrite and invoke it.
static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, T &&value)
Store a single result within the result list.
std::enable_if_t< std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Build a constraint function from the given function ConstraintFnT.
constexpr bool always_false
A utility variable that always resolves to false.
std::enable_if_t< std::is_convertible< RewriteFnT, PDLRewriteFunction >::value, PDLRewriteFunction > buildRewriteFn(RewriteFnT &&rewriteFn)
Build a rewrite function from the given function RewriteFnT.
@ Type
An inlay hint that is for a type annotation.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Base class for listeners.
Definition: Builders.h:262
Kind
The kind of listener.
Definition: Builders.h:264
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
virtual void notifyBlockCreated(Block *block)
Notification handler for when a block is created using the builder.
Definition: Builders.h:294
virtual void notifyOperationInserted(Operation *op)
Notification handler for when an operation is inserted into the builder.
Definition: Builders.h:290
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:372
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:168
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:446
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:455
void notifyOperationInserted(Operation *op) override
Notification handler for when an operation is inserted into the builder.
Definition: PatternMatch.h:449
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that the specified operation is about to be replaced with another operation.
Definition: PatternMatch.h:459
void notifyOperationRemoved(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:468
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
Definition: PatternMatch.h:472
void notifyBlockCreated(Block *block) override
Notification handler for when a block is created using the builder.
Definition: PatternMatch.h:452
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:447
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that the specified operation is about to be replaced with a range of values,...
Definition: PatternMatch.h:463
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:406
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that the specified operation is about to be replaced with another operation.
Definition: PatternMatch.h:414
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
Definition: PatternMatch.h:435
virtual void notifyOperationRemoved(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:427
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that the specified operation is about to be replaced with a range of values,...
Definition: PatternMatch.h:422
static bool classof(const OpBuilder::Listener *base)
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:318
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: PatternMatch.h:335
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: PatternMatch.h:322
virtual LogicalResult match(SourceOp op) const
Definition: PatternMatch.h:338
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:325
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:341
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:328
This struct provides a simplified model for processing types that have "builtin" PDLValue support:
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
This struct provides a simplified model for processing types that inherit from builtin PDLValue types...
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT baseValue, size_t argIdx)
This struct provides a simplified model for processing types that are based on another type,...
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT value, size_t argIdx)
Explicitly add the expected parent API to ensure the parent class implements the necessary API (and d...
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static void processAsResult(PatternRewriter &, PDLResultList &results, OperandRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, ResultRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Type, N > values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Value, N > values)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< OperandRange > types)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< ResultRange > types)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
This struct provides a convenient way to determine how to process a given type as either a PDL parame...