MLIR  22.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 #include <optional>
17 
19 namespace mlir {
20 
21 class PatternRewriter;
22 
23 //===----------------------------------------------------------------------===//
24 // PatternBenefit class
25 //===----------------------------------------------------------------------===//
26 
27 /// This class represents the benefit of a pattern match in a unitless scheme
28 /// that ranges from 0 (very little benefit) to 65K. The most common unit to
29 /// use here is the "number of operations matched" by the pattern.
30 ///
31 /// This also has a sentinel representation that can be used for patterns that
32 /// fail to match.
33 ///
35  enum { ImpossibleToMatchSentinel = 65535 };
36 
37 public:
38  PatternBenefit() = default;
39  PatternBenefit(unsigned benefit);
40  PatternBenefit(const PatternBenefit &) = default;
42 
44  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
45 
46  /// If the corresponding pattern can match, return its benefit. If the
47  /// corresponding pattern isImpossibleToMatch() then this aborts.
48  unsigned short getBenefit() const;
49 
50  bool operator==(const PatternBenefit &rhs) const {
51  return representation == rhs.representation;
52  }
53  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
54  bool operator<(const PatternBenefit &rhs) const {
55  return representation < rhs.representation;
56  }
57  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
58  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
59  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
60 
61 private:
62  unsigned short representation{ImpossibleToMatchSentinel};
63 };
64 
65 //===----------------------------------------------------------------------===//
66 // Pattern
67 //===----------------------------------------------------------------------===//
68 
69 /// This class contains all of the data related to a pattern, but does not
70 /// contain any methods or logic for the actual matching. This class is solely
71 /// used to interface with the metadata of a pattern, such as the benefit or
72 /// root operation.
73 class Pattern {
74  /// This enum represents the kind of value used to select the root operations
75  /// that match this pattern.
76  enum class RootKind {
77  /// The pattern root matches "any" operation.
78  Any,
79  /// The pattern root is matched using a concrete operation name.
81  /// The pattern root is matched using an interface ID.
82  InterfaceID,
83  /// The pattern root is matched using a trait ID.
84  TraitID
85  };
86 
87 public:
88  /// Return a list of operations that may be generated when rewriting an
89  /// operation instance with this pattern.
90  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91 
92  /// Return the root node that this pattern matches. Patterns that can match
93  /// multiple root types return std::nullopt.
94  std::optional<OperationName> getRootKind() const {
95  if (rootKind == RootKind::OperationName)
96  return OperationName::getFromOpaquePointer(rootValue);
97  return std::nullopt;
98  }
99 
100  /// Return the interface ID used to match the root operation of this pattern.
101  /// If the pattern does not use an interface ID for deciding the root match,
102  /// this returns std::nullopt.
103  std::optional<TypeID> getRootInterfaceID() const {
104  if (rootKind == RootKind::InterfaceID)
105  return TypeID::getFromOpaquePointer(rootValue);
106  return std::nullopt;
107  }
108 
109  /// Return the trait ID used to match the root operation of this pattern.
110  /// If the pattern does not use a trait ID for deciding the root match, this
111  /// returns std::nullopt.
112  std::optional<TypeID> getRootTraitID() const {
113  if (rootKind == RootKind::TraitID)
114  return TypeID::getFromOpaquePointer(rootValue);
115  return std::nullopt;
116  }
117 
118  /// Return the benefit (the inverse of "cost") of matching this pattern. The
119  /// benefit of a Pattern is always static - rewrites that may have dynamic
120  /// benefit can be instantiated multiple times (different Pattern instances)
121  /// for each benefit that they may return, and be guarded by different match
122  /// condition predicates.
123  PatternBenefit getBenefit() const { return benefit; }
124 
125  /// Returns true if this pattern is known to result in recursive application,
126  /// i.e. this pattern may generate IR that also matches this pattern, but is
127  /// known to bound the recursion. This signals to a rewrite driver that it is
128  /// safe to apply this pattern recursively to generated IR.
130  return contextAndHasBoundedRecursion.getInt();
131  }
132 
133  /// Return the MLIRContext used to create this pattern.
135  return contextAndHasBoundedRecursion.getPointer();
136  }
137 
138  /// Return a readable name for this pattern. This name should only be used for
139  /// debugging purposes, and may be empty.
140  StringRef getDebugName() const { return debugName; }
141 
142  /// Set the human readable debug name used for this pattern. This name will
143  /// only be used for debugging purposes.
144  void setDebugName(StringRef name) { debugName = name; }
145 
146  /// Return the set of debug labels attached to this pattern.
147  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148 
149  /// Add the provided debug labels to this pattern.
151  debugLabels.append(labels.begin(), labels.end());
152  }
153  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
154 
155 protected:
156  /// This class acts as a special tag that makes the desire to match "any"
157  /// operation type explicit. This helps to avoid unnecessary usages of this
158  /// feature, and ensures that the user is making a conscious decision.
159  struct MatchAnyOpTypeTag {};
160  /// This class acts as a special tag that makes the desire to match any
161  /// operation that implements a given interface explicit. This helps to avoid
162  /// unnecessary usages of this feature, and ensures that the user is making a
163  /// conscious decision.
165  /// This class acts as a special tag that makes the desire to match any
166  /// operation that implements a given trait explicit. This helps to avoid
167  /// unnecessary usages of this feature, and ensures that the user is making a
168  /// conscious decision.
170 
171  /// Construct a pattern with a certain benefit that matches the operation
172  /// with the given root name.
173  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174  ArrayRef<StringRef> generatedNames = {});
175  /// Construct a pattern that may match any operation type. `generatedNames`
176  /// contains the names of operations that may be generated during a successful
177  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179  /// always be supplied here.
180  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181  ArrayRef<StringRef> generatedNames = {});
182  /// Construct a pattern that may match any operation that implements the
183  /// interface defined by the provided `interfaceID`. `generatedNames` contains
184  /// the names of operations that may be generated during a successful rewrite.
185  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186  /// interface" behavior is what the user actually desired,
187  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189  PatternBenefit benefit, MLIRContext *context,
190  ArrayRef<StringRef> generatedNames = {});
191  /// Construct a pattern that may match any operation that implements the
192  /// trait defined by the provided `traitID`. `generatedNames` contains the
193  /// names of operations that may be generated during a successful rewrite.
194  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196  /// always be supplied here.
197  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199 
200  /// Set the flag detailing if this pattern has bounded rewrite recursion or
201  /// not.
202  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204  }
205 
206 private:
207  Pattern(const void *rootValue, RootKind rootKind,
208  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209  MLIRContext *context);
210 
211  /// The value used to match the root operation of the pattern.
212  const void *rootValue;
213  RootKind rootKind;
214 
215  /// The expected benefit of matching this pattern.
216  const PatternBenefit benefit;
217 
218  /// The context this pattern was created from, and a boolean flag indicating
219  /// whether this pattern has bounded recursion or not.
220  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221 
222  /// A list of the potential operations that may be generated when rewriting
223  /// an op with this pattern.
224  SmallVector<OperationName, 2> generatedOps;
225 
226  /// A readable name for this pattern. May be empty.
227  StringRef debugName;
228 
229  /// The set of debug labels attached to this pattern.
230  SmallVector<StringRef, 0> debugLabels;
231 };
232 
233 //===----------------------------------------------------------------------===//
234 // RewritePattern
235 //===----------------------------------------------------------------------===//
236 
237 /// RewritePattern is the common base class for all DAG to DAG replacements.
    /// Derived patterns implement `matchAndRewrite`; the `create` factory below
    /// builds a derived pattern and runs its optional `initialize` hook.
238 class RewritePattern : public Pattern {
239 public:
240  virtual ~RewritePattern() = default;
241 
242  /// Attempt to match against code rooted at the specified operation,
243  /// which is the same operation code as getRootKind(). If successful, perform
244  /// the rewrite.
245  ///
246  /// Note: Implementations must modify the IR if and only if the function
247  /// returns "success".
248  virtual LogicalResult matchAndRewrite(Operation *op,
249  PatternRewriter &rewriter) const = 0;
250 
251  /// Factory for derived rewrite patterns of the given type `T`. Constructs
252  /// `T` from the forwarded `args`, calls `T::initialize()` when `T` provides
     /// one, assigns the type name of `T` as the debug name if the pattern did
     /// not set one itself, and returns the owning pointer.
253  template <typename T, typename... Args>
254  static std::unique_ptr<T> create(Args &&...args) {
255  std::unique_ptr<T> pattern =
256  std::make_unique<T>(std::forward<Args>(args)...);
257  initializePattern<T>(*pattern);
258 
259  // Set a default debug name if one wasn't provided.
260  if (pattern->getDebugName().empty())
261  pattern->setDebugName(llvm::getTypeName<T>());
262  return pattern;
263  }
264 
265 protected:
266  /// Inherit the base constructors from `Pattern`.
267  using Pattern::Pattern;
268 
269 private:
270  /// Trait to check if T provides an `initialize` method. The alias is only
     /// well-formed when `T` has an `initialize()` member callable without
     /// arguments; the `Args` pack is not used by the alias itself.
271  template <typename T, typename... Args>
272  using has_initialize = decltype(std::declval<T>().initialize());
273  template <typename T>
274  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
275 
276  /// Initialize the derived pattern by calling its `initialize` method if
277  /// available. The check is made at compile time, so nothing is called for
     /// types without such a method.
278  template <typename T>
279  static void initializePattern(T &pattern) {
280  if constexpr (detect_has_initialize<T>::value)
281  pattern.initialize();
282  }
283 
284  /// An anchor for the virtual table.
285  virtual void anchor();
286 };
287 
288 namespace detail {
289 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
290 /// allows for matching and rewriting against an instance of a derived operation
291 /// class or Interface.
292 template <typename SourceOp>
294  using RewritePattern::RewritePattern;
295 
296  /// Wrapper around the RewritePattern method that passes the derived op type.
297  LogicalResult matchAndRewrite(Operation *op,
298  PatternRewriter &rewriter) const final {
299  return matchAndRewrite(cast<SourceOp>(op), rewriter);
300  }
301 
302  /// Method that operates on the SourceOp type. Must be overridden by the
303  /// derived pattern class.
304  virtual LogicalResult matchAndRewrite(SourceOp op,
305  PatternRewriter &rewriter) const = 0;
306 };
307 } // namespace detail
308 
309 /// OpRewritePattern is a wrapper around RewritePattern that allows for
310 /// matching and rewriting against an instance of a derived operation class as
311 /// opposed to a raw Operation.
312 template <typename SourceOp>
315  /// Type alias to allow derived classes to inherit constructors with
316  /// `using Base::Base;`.
318 
319  /// Patterns must specify the root operation name they match against, and can
320  /// also specify the benefit of the pattern matching and a list of generated
321  /// ops.
323  ArrayRef<StringRef> generatedNames = {})
325  SourceOp::getOperationName(), benefit, context, generatedNames) {}
326 };
327 
328 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
329 /// matching and rewriting against an instance of an operation interface instead
330 /// of a raw Operation.
331 template <typename SourceOp>
334  /// Type alias to allow derived classes to inherit constructors with
335  /// `using Base::Base;`.
337 
339  : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
340  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
341  benefit, context) {}
342 };
343 
344 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
345 /// matching and rewriting against instances of an operation that possess a
346 /// given trait.
347 template <template <typename> class TraitType>
349 public:
350  /// Type alias to allow derived classes to inherit constructors with
351  /// `using Base::Base;`.
353 
356  benefit, context) {}
357 };
358 
359 //===----------------------------------------------------------------------===//
360 // RewriterBase
361 //===----------------------------------------------------------------------===//
362 
363 /// This class coordinates the application of a rewrite on a set of IR,
364 /// providing a way for clients to track mutations and create new operations.
365 /// This class serves as a common API for IR mutation between pattern rewrites
366 /// and non-pattern rewrites, and facilitates the development of shared
367 /// IR transformation utilities.
368 class RewriterBase : public OpBuilder {
369 public:
370  struct Listener : public OpBuilder::Listener {
372  : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
373 
374  /// Notify the listener that the specified block is about to be erased.
375  /// At this point, the block has zero uses.
376  virtual void notifyBlockErased(Block *block) {}
377 
378  /// Notify the listener that the specified operation was modified in-place.
379  virtual void notifyOperationModified(Operation *op) {}
380 
381  /// Notify the listener that all uses of the specified operation's results
382  /// are about to be replaced with the results of another operation. This is
383  /// called before the uses of the old operation have been changed.
384  ///
385  /// By default, this function calls the "operation replaced with values"
386  /// notification.
388  Operation *replacement) {
389  notifyOperationReplaced(op, replacement->getResults());
390  }
391 
392  /// Notify the listener that all uses of the specified operation's results
393  /// are about to be replaced with a range of values, potentially
394  /// produced by other operations. This is called before the uses of the
395  /// operation have been changed.
397  ValueRange replacement) {}
398 
399  /// Notify the listener that the specified operation is about to be erased.
400  /// At this point, the operation has zero uses.
401  ///
402  /// Note: This notification is not triggered when unlinking an operation.
403  virtual void notifyOperationErased(Operation *op) {}
404 
405  /// Notify the listener that the specified pattern is about to be applied
406  /// at the specified root operation.
407  virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
408 
409  /// Notify the listener that a pattern application finished with the
410  /// specified status. "success" indicates that the pattern was applied
411  /// successfully. "failure" indicates that the pattern could not be
412  /// applied. The pattern may have communicated the reason for the failure
413  /// with `notifyMatchFailure`.
414  virtual void notifyPatternEnd(const Pattern &pattern,
415  LogicalResult status) {}
416 
417  /// Notify the listener that the pattern failed to match, and provide a
418  /// callback to populate a diagnostic with the reason why the failure
419  /// occurred. This method allows for derived listeners to optionally hook
420  /// into the reason why a rewrite failed, and display it to users.
421  virtual void
423  function_ref<void(Diagnostic &)> reasonCallback) {}
424 
425  static bool classof(const OpBuilder::Listener *base);
426  };
427 
428  /// A listener that forwards all notifications to another listener. This
429  /// struct can be used as a base to create listener chains, so that multiple
430  /// listeners can be notified of IR changes.
433  : listener(listener),
434  rewriteListener(
435  dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
436 
437  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
438  if (listener)
439  listener->notifyOperationInserted(op, previous);
440  }
441  void notifyBlockInserted(Block *block, Region *previous,
442  Region::iterator previousIt) override {
443  if (listener)
444  listener->notifyBlockInserted(block, previous, previousIt);
445  }
446  void notifyBlockErased(Block *block) override {
447  if (rewriteListener)
448  rewriteListener->notifyBlockErased(block);
449  }
450  void notifyOperationModified(Operation *op) override {
451  if (rewriteListener)
452  rewriteListener->notifyOperationModified(op);
453  }
454  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
455  if (rewriteListener)
456  rewriteListener->notifyOperationReplaced(op, newOp);
457  }
459  ValueRange replacement) override {
460  if (rewriteListener)
461  rewriteListener->notifyOperationReplaced(op, replacement);
462  }
463  void notifyOperationErased(Operation *op) override {
464  if (rewriteListener)
465  rewriteListener->notifyOperationErased(op);
466  }
467  void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
468  if (rewriteListener)
469  rewriteListener->notifyPatternBegin(pattern, op);
470  }
471  void notifyPatternEnd(const Pattern &pattern,
472  LogicalResult status) override {
473  if (rewriteListener)
474  rewriteListener->notifyPatternEnd(pattern, status);
475  }
477  Location loc,
478  function_ref<void(Diagnostic &)> reasonCallback) override {
479  if (rewriteListener)
480  rewriteListener->notifyMatchFailure(loc, reasonCallback);
481  }
482 
483  private:
484  OpBuilder::Listener *listener;
485  RewriterBase::Listener *rewriteListener;
486  };
487 
488  /// A listener that logs notification events to llvm::dbgs() before
489  /// forwarding to the base listener.
491  PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
492  : RewriterBase::ForwardingListener(listener), patternName(patternName) {
493  }
494 
495  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
496  void notifyOperationModified(Operation *op) override;
497  void notifyOperationReplaced(Operation *op, Operation *newOp) override;
499  ValueRange replacement) override;
500  void notifyOperationErased(Operation *op) override;
501  void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
502 
503  private:
504  StringRef patternName;
505  };
506 
507  /// Move the blocks that belong to "region" before the given position in
508  /// another region "parent". The two regions must be different. The caller
509  /// is responsible for creating or updating the operation transferring flow
510  /// of control to the region and passing it the correct block arguments.
511  void inlineRegionBefore(Region &region, Region &parent,
512  Region::iterator before);
513  void inlineRegionBefore(Region &region, Block *before);
514 
515  /// Replace the results of the given (original) operation with the specified
516  /// list of values (replacements). The result types of the given op and the
517  /// replacements must match. The original op is erased.
518  virtual void replaceOp(Operation *op, ValueRange newValues);
519 
520  /// Replace the results of the given (original) operation with the specified
521  /// new op (replacement). The result types of the two ops must match. The
522  /// original op is erased.
523  virtual void replaceOp(Operation *op, Operation *newOp);
524 
525  /// Replace the results of the given (original) op with a new op that is
526  /// created without verification (replacement). The result values of the two
527  /// ops must match. The original op is erased.
     ///
     /// The replacement is built with `OpTy::create` at the location of `op`,
     /// forwarding `args`, and is then handed to `replaceOp`. Returns the newly
     /// created op.
528  template <typename OpTy, typename... Args>
529  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
     // View this rewriter through its OpBuilder base for `OpTy::create`.
530  auto builder = static_cast<OpBuilder *>(this);
531  auto newOp =
532  OpTy::create(*builder, op->getLoc(), std::forward<Args>(args)...);
533  replaceOp(op, newOp.getOperation());
534  return newOp;
535  }
536 
537  /// This method erases an operation that is known to have no uses.
538  ///
539  /// If the current insertion point is before the erased operation, it is
540  /// adjusted to the following operation (or the end of the block). If the
541  /// current insertion point is within the erased operation, the insertion
542  /// point is left in an invalid state.
543  virtual void eraseOp(Operation *op);
544 
545  /// This method erases all operations in a block.
546  virtual void eraseBlock(Block *block);
547 
548  /// Inline the operations of block 'source' into block 'dest' before the given
549  /// position. The source block will be deleted and must have no uses.
550  /// 'argValues' is used to replace the block arguments of 'source'.
551  ///
552  /// If the source block is inserted at the end of the dest block, the dest
553  /// block must have no successors. Similarly, if the source block is inserted
554  /// somewhere in the middle (or beginning) of the dest block, the source block
555  /// must have no successors. Otherwise, the resulting IR would have
556  /// unreachable operations.
557  ///
558  /// If the insertion point is within the source block, it is adjusted to the
559  /// destination block.
560  virtual void inlineBlockBefore(Block *source, Block *dest,
561  Block::iterator before,
562  ValueRange argValues = {});
563 
564  /// Inline the operations of block 'source' before the operation 'op'. The
565  /// source block will be deleted and must have no uses. 'argValues' is used to
566  /// replace the block arguments of 'source'.
567  ///
568  /// The source block must have no successors. Otherwise, the resulting IR
569  /// would have unreachable operations.
570  ///
571  /// If the insertion point is within the source block, it is adjusted to the
572  /// destination block.
573  void inlineBlockBefore(Block *source, Operation *op,
574  ValueRange argValues = {});
575 
576  /// Inline the operations of block 'source' into the end of block 'dest'. The
577  /// source block will be deleted and must have no uses. 'argValues' is used to
578  /// replace the block arguments of 'source'.
579  ///
580  /// The dest block must have no successors. Otherwise, the resulting IR would
581  /// have unreachable operations.
582  ///
583  /// If the insertion point is within the source block, it is adjusted to the
584  /// destination block.
585  void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {});
586 
587  /// Split the operations starting at "before" (inclusive) out of the given
588  /// block into a new block, and return it.
589  Block *splitBlock(Block *block, Block::iterator before);
590 
591  /// Unlink this operation from its current block and insert it right before
592  /// `existingOp` which may be in the same or another block in the same
593  /// function.
594  void moveOpBefore(Operation *op, Operation *existingOp);
595 
596  /// Unlink this operation from its current block and insert it right before
597  /// `iterator` in the specified block.
598  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
599 
600  /// Unlink this operation from its current block and insert it right after
601  /// `existingOp` which may be in the same or another block in the same
602  /// function.
603  void moveOpAfter(Operation *op, Operation *existingOp);
604 
605  /// Unlink this operation from its current block and insert it right after
606  /// `iterator` in the specified block.
607  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
608 
609  /// Unlink this block and insert it right before `anotherBlock`.
610  void moveBlockBefore(Block *block, Block *anotherBlock);
611 
612  /// Unlink this block and insert it right before the location that the given
613  /// iterator points to in the given region.
614  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
615 
616  /// This method is used to notify the rewriter that an in-place operation
617  /// modification is about to happen. A call to this function *must* be
618  /// followed by a call to either `finalizeOpModification` or
619  /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
620  /// a new operation and removing the old one) but also often allows simpler
621  /// code in the client.
622  virtual void startOpModification(Operation *op) {}
623 
624  /// This method is used to signal the end of an in-place modification of the
625  /// given operation. This can only be called on operations that were provided
626  /// to a call to `startOpModification`.
627  virtual void finalizeOpModification(Operation *op);
628 
629  /// This method cancels a pending in-place modification. This can only be
630  /// called on operations that were provided to a call to
631  /// `startOpModification`.
632  virtual void cancelOpModification(Operation *op) {}
633 
634  /// This method is a utility wrapper around an in-place modification of an
635  /// operation. It wraps calls to `startOpModification` and
636  /// `finalizeOpModification` around the given callable.
637  template <typename CallableT>
638  void modifyOpInPlace(Operation *root, CallableT &&callable) {
639  startOpModification(root);
640  callable();
642  }
643 
644  /// Find uses of `from` and replace them with `to`. Also notify the listener
645  /// about every in-place op modification (for every use that was replaced).
     /// Each use is updated individually, wrapped in `modifyOpInPlace` on the
     /// operation that owns the operand.
646  virtual void replaceAllUsesWith(Value from, Value to) {
     // Early-increment iteration: `operand.set(to)` rewires the use while the
     // uses of `from` are still being walked.
647  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
648  Operation *op = operand.getOwner();
649  modifyOpInPlace(op, [&]() { operand.set(to); });
650  }
651  }
     /// Block overload: point every `BlockOperand` that uses `from` at `to`
     /// instead. Each update is wrapped in `modifyOpInPlace` on the operation
     /// that owns the operand.
652  void replaceAllUsesWith(Block *from, Block *to) {
     // Early-increment iteration: `operand.set(to)` rewires the use while the
     // uses of `from` are still being walked.
653  for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
654  Operation *op = operand.getOwner();
655  modifyOpInPlace(op, [&]() { operand.set(to); });
656  }
657  }
659  assert(from.size() == to.size() && "incorrect number of replacements");
660  for (auto it : llvm::zip(from, to))
661  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
662  }
663 
664  /// Find uses of `from` and replace them with `to`. Also notify the listener
665  /// about every in-place op modification (for every use that was replaced)
666  /// and that the `from` operation is about to be replaced.
667  ///
668  /// Note: This function cannot be called `replaceAllUsesWith` because the
669  /// overload resolution, when called with an op that can be implicitly
670  /// converted to a Value, would be ambiguous.
671  void replaceAllOpUsesWith(Operation *from, ValueRange to);
672  void replaceAllOpUsesWith(Operation *from, Operation *to);
673 
674  /// Find uses of `from` and replace them with `to` if the `functor` returns
675  /// true. Also notify the listener about every in-place op modification (for
676  /// every use that was replaced). The optional `allUsesReplaced` flag is set
677  /// to "true" if all uses were replaced.
678  void replaceUsesWithIf(Value from, Value to,
679  function_ref<bool(OpOperand &)> functor,
680  bool *allUsesReplaced = nullptr);
682  function_ref<bool(OpOperand &)> functor,
683  bool *allUsesReplaced = nullptr);
684  // Note: This function cannot be called `replaceOpUsesWithIf` because the
685  // overload resolution, when called with an op that can be implicitly
686  // converted to a Value, would be ambiguous.
688  function_ref<bool(OpOperand &)> functor,
689  bool *allUsesReplaced = nullptr) {
690  replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
691  }
692 
693  /// Find uses of `from` within `block` and replace them with `to`. Also notify
694  /// the listener about every in-place op modification (for every use that was
695  /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
696  /// uses were replaced.
698  Block *block, bool *allUsesReplaced = nullptr) {
700  op, newValues,
701  [block](OpOperand &use) {
702  return block->getParentOp()->isProperAncestor(use.getOwner());
703  },
704  allUsesReplaced);
705  }
706 
707  /// Find uses of `from` and replace them with `to` except if the user is
708  /// `exceptedUser`. Also notify the listener about every in-place op
709  /// modification (for every use that was replaced).
     ///
     /// Delegates to `replaceUsesWithIf` with a predicate that accepts a use
     /// only when its owning operation is not `exceptedUser`; the optional
     /// `allUsesReplaced` output of that function is not requested.
710  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
711  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
712  Operation *user = use.getOwner();
713  return user != exceptedUser;
714  });
715  }
716  void replaceAllUsesExcept(Value from, Value to,
717  const SmallPtrSetImpl<Operation *> &preservedUsers);
718 
719  /// Used to notify the listener that the IR failed to be rewritten because of
720  /// a match failure, and provide a callback to populate a diagnostic with the
721  /// reason why the failure occurred. This method allows for derived rewriters
722  /// to optionally hook into the reason why a rewrite failed, and display it to
723  /// users.
     ///
     /// All overloads return `failure()`, so a pattern can write
     /// `return rewriter.notifyMatchFailure(...)`. The callback is forwarded only
     /// when the attached listener is a `RewriterBase::Listener`; otherwise no
     /// notification is sent. The callback overloads are disabled for arguments
     /// convertible to `Twine`, so such arguments select the message overloads
     /// further down.
724  template <typename CallbackT>
725  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
726  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
727  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
728  rewriteListener->notifyMatchFailure(
729  loc, function_ref<void(Diagnostic &)>(reasonCallback));
730  return failure();
731  }
     /// Same as above, reporting the failure at the location of `op`.
732  template <typename CallbackT>
733  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
734  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
735  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
736  rewriteListener->notifyMatchFailure(
737  op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
738  return failure();
739  }
     /// Message overload: wraps `msg` in a callback that streams it into the
     /// diagnostic and forwards `arg` to one of the callback overloads above.
740  template <typename ArgT>
741  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
742  return notifyMatchFailure(std::forward<ArgT>(arg),
743  [&](Diagnostic &diag) { diag << msg; });
744  }
     /// C-string overload: converts `msg` to a `Twine` and defers to the `Twine`
     /// overload.
745  template <typename ArgT>
746  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
747  return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
748  }
749 
750 protected:
751  /// Initialize the builder.
752  explicit RewriterBase(MLIRContext *ctx,
753  OpBuilder::Listener *listener = nullptr)
754  : OpBuilder(ctx, listener) {}
755  explicit RewriterBase(const OpBuilder &otherBuilder)
756  : OpBuilder(otherBuilder) {}
758  : OpBuilder(op, listener) {}
759  virtual ~RewriterBase();
760 
761 private:
762  void operator=(const RewriterBase &) = delete;
763  RewriterBase(const RewriterBase &) = delete;
764 };
765 
766 //===----------------------------------------------------------------------===//
767 // IRRewriter
768 //===----------------------------------------------------------------------===//
769 
770 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
771 /// providing a way to keep track of the mutations made to the IR. This class
772 /// should only be used in situations where another `RewriterBase` instance,
773 /// such as a `PatternRewriter`, is not available.
774 class IRRewriter : public RewriterBase {
775 public:
777  : RewriterBase(ctx, listener) {}
778  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
780  : RewriterBase(op, listener) {}
781 };
782 
783 //===----------------------------------------------------------------------===//
784 // PatternRewriter
785 //===----------------------------------------------------------------------===//
786 
787 /// A special type of `RewriterBase` that coordinates the application of a
788 /// rewrite pattern on the current IR being matched, providing a way to keep
789 /// track of any mutations made. This class should be used to perform all
790 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
791 /// be tracking various state that would be invalidated when a mutation takes
792 /// place.
794 public:
795  explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
797 
798  /// A hook used to indicate if the pattern rewriter can recover from failure
799  /// during the rewrite stage of a pattern. For example, if the pattern
800  /// rewriter supports rollback, it may progress smoothly even if IR was
801  /// changed during the rewrite.
802  virtual bool canRecoverFromRewriteFailure() const { return false; }
803 };
804 
805 } // namespace mlir
806 
807 // Optionally expose PDL pattern matching methods.
808 #include "PDLPatternMatch.h.inc"
809 
810 namespace mlir {
811 
812 //===----------------------------------------------------------------------===//
813 // RewritePatternSet
814 //===----------------------------------------------------------------------===//
815 
817  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
818 
819 public:
820  RewritePatternSet(MLIRContext *context) : context(context) {}
821 
822  /// Construct a RewritePatternSet populated with the given pattern.
824  std::unique_ptr<RewritePattern> pattern)
825  : context(context) {
826  nativePatterns.emplace_back(std::move(pattern));
827  }
828  RewritePatternSet(PDLPatternModule &&pattern)
829  : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
830 
831  MLIRContext *getContext() const { return context; }
832 
833  /// Return the native patterns held in this list.
834  NativePatternListT &getNativePatterns() { return nativePatterns; }
835 
836  /// Return the PDL patterns held in this list.
837  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
838 
839  /// Clear out all of the held patterns in this list.
840  void clear() {
841  nativePatterns.clear();
842  pdlPatterns.clear();
843  }
844 
845  //===--------------------------------------------------------------------===//
846  // 'add' methods for adding patterns to the set.
847  //===--------------------------------------------------------------------===//
848 
849  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
850  /// the given arguments. Return a reference to `this` for chaining insertions.
851  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
852  template <typename... Ts, typename ConstructorArg,
853  typename... ConstructorArgs,
854  typename = std::enable_if_t<sizeof...(Ts) != 0>>
855  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
856  // The following expands a call to emplace_back for each of the pattern
857  // types 'Ts'.
858  (addImpl<Ts>(/*debugLabels=*/{}, std::forward<ConstructorArg>(arg),
859  std::forward<ConstructorArgs>(args)...),
860  ...);
861  return *this;
862  }
863  /// An overload of the above `add` method that allows for attaching a set
864  /// of debug labels to the attached patterns. This is useful for labeling
865  /// groups of patterns that may be shared between multiple different
866  /// passes/users.
867  template <typename... Ts, typename ConstructorArg,
868  typename... ConstructorArgs,
869  typename = std::enable_if_t<sizeof...(Ts) != 0>>
871  ConstructorArg &&arg,
872  ConstructorArgs &&...args) {
873  // The following expands a call to emplace_back for each of the pattern
874  // types 'Ts'.
875  (addImpl<Ts>(debugLabels, arg, args...), ...);
876  return *this;
877  }
878 
879  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
880  /// `this` for chaining insertions.
881  template <typename... Ts>
883  (addImpl<Ts>(), ...);
884  return *this;
885  }
886 
887  /// Add the given native pattern to the pattern list. Return a reference to
888  /// `this` for chaining insertions.
889  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
890  nativePatterns.emplace_back(std::move(pattern));
891  return *this;
892  }
893 
894  /// Add the given PDL pattern to the pattern list. Return a reference to
895  /// `this` for chaining insertions.
896  RewritePatternSet &add(PDLPatternModule &&pattern) {
897  pdlPatterns.mergeIn(std::move(pattern));
898  return *this;
899  }
900 
901  // Add a matchAndRewrite style pattern represented as a C function pointer.
902  template <typename OpType>
904  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
905  PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
906  struct FnPattern final : public OpRewritePattern<OpType> {
907  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
908  MLIRContext *context, PatternBenefit benefit,
909  ArrayRef<StringRef> generatedNames)
910  : OpRewritePattern<OpType>(context, benefit, generatedNames),
911  implFn(implFn) {}
912 
913  LogicalResult matchAndRewrite(OpType op,
914  PatternRewriter &rewriter) const override {
915  return implFn(op, rewriter);
916  }
917 
918  private:
919  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
920  };
921  add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
922  generatedNames));
923  return *this;
924  }
925 
926  //===--------------------------------------------------------------------===//
927  // Pattern Insertion
928  //===--------------------------------------------------------------------===//
929 
930  // TODO: These are soft deprecated in favor of the 'add' methods above.
931 
932  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
933  /// the given arguments. Return a reference to `this` for chaining insertions.
934  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
935  template <typename... Ts, typename ConstructorArg,
936  typename... ConstructorArgs,
937  typename = std::enable_if_t<sizeof...(Ts) != 0>>
938  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
939  // The following expands a call to emplace_back for each of the pattern
940  // types 'Ts'.
941  (addImpl<Ts>(/*debugLabels=*/{}, arg, args...), ...);
942  return *this;
943  }
944 
945  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
946  /// `this` for chaining insertions.
947  template <typename... Ts>
949  (addImpl<Ts>(), ...);
950  return *this;
951  }
952 
953  /// Add the given native pattern to the pattern list. Return a reference to
954  /// `this` for chaining insertions.
955  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
956  nativePatterns.emplace_back(std::move(pattern));
957  return *this;
958  }
959 
960  /// Add the given PDL pattern to the pattern list. Return a reference to
961  /// `this` for chaining insertions.
962  RewritePatternSet &insert(PDLPatternModule &&pattern) {
963  pdlPatterns.mergeIn(std::move(pattern));
964  return *this;
965  }
966 
967  // Add a matchAndRewrite style pattern represented as a C function pointer.
968  template <typename OpType>
970  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
971  struct FnPattern final : public OpRewritePattern<OpType> {
972  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
973  MLIRContext *context)
974  : OpRewritePattern<OpType>(context), implFn(implFn) {
975  this->setDebugName(llvm::getTypeName<FnPattern>());
976  }
977 
978  LogicalResult matchAndRewrite(OpType op,
979  PatternRewriter &rewriter) const override {
980  return implFn(op, rewriter);
981  }
982 
983  private:
984  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
985  };
986  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
987  return *this;
988  }
989 
990 private:
991  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
992  /// chaining insertions.
993  template <typename T, typename... Args>
994  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
995  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
996  std::unique_ptr<T> pattern =
997  RewritePattern::create<T>(std::forward<Args>(args)...);
998  pattern->addDebugLabels(debugLabels);
999  nativePatterns.emplace_back(std::move(pattern));
1000  }
1001 
1002  template <typename T, typename... Args>
1003  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1004  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1005  // TODO: Add the provided labels to the PDL pattern when PDL supports
1006  // labels.
1007  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1008  }
1009 
1010  MLIRContext *const context;
1011  NativePatternListT nativePatterns;
1012 
1013  // Patterns expressed with PDL. This will compile to a stub class when PDL is
1014  // not enabled.
1015  PDLPatternModule pdlPatterns;
1016 };
1017 
1018 } // namespace mlir
1019 
1020 #endif // MLIR_IR_PATTERNMATCH_H
static std::string diag(const llvm::Value &value)
A block operand represents an operand that holds a reference to a Block, e.g.
Definition: BlockSupport.h:30
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: UseDefLists.h:253
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:776
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:778
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:779
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a saved insertion point.
Definition: Builders.h:327
This class helps build Operations.
Definition: Builders.h:207
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:616
This class represents an operand of an operation.
Definition: Value.h:257
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:348
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:354
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:54
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:50
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:59
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:58
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:53
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
PatternRewriter(MLIRContext *ctx)
Definition: PatternMatch.h:795
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:802
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:103
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:147
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:202
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:123
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:112
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:144
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:153
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:150
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:896
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Definition: PatternMatch.h:904
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
Definition: PatternMatch.h:834
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:889
RewritePatternSet(PDLPatternModule &&pattern)
Definition: PatternMatch.h:828
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:938
MLIRContext * getContext() const
Definition: PatternMatch.h:831
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:840
RewritePatternSet(MLIRContext *context)
Definition: PatternMatch.h:820
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:882
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
Definition: PatternMatch.h:823
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:962
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:955
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:837
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:948
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
Definition: PatternMatch.h:870
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
Definition: PatternMatch.h:970
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:254
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const =0
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Definition: PatternMatch.h:734
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Definition: PatternMatch.h:687
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Definition: PatternMatch.h:741
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:755
void replaceAllUsesWith(ValueRange from, ValueRange to)
Definition: PatternMatch.h:658
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before existingBlock.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:757
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:632
void replaceAllUsesWith(Block *from, Block *to)
Definition: PatternMatch.h:652
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:710
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Definition: PatternMatch.h:752
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
Definition: PatternMatch.h:746
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Definition: PatternMatch.h:697
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:622
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:135
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Base class for listeners.
Definition: Builders.h:264
Kind
The kind of listener.
Definition: Builders.h:266
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:298
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:333
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:338
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:322
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:431
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:476
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:437
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:471
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:450
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:467
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:454
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:463
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:446
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:432
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:458
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
Definition: PatternMatch.h:441
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:422
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:379
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:403
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:387
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:376
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:414
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:407
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:396
static bool classof(const OpBuilder::Listener *base)
A listener that logs notification events to llvm::dbgs() before forwarding to the base listener.
Definition: PatternMatch.h:490
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
Definition: PatternMatch.h:491
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:293
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const =0
Method that operates on the SourceOp type.