MLIR  22.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 #include <optional>
17 
19 namespace mlir {
20 
// Forward declaration; PatternRewriter is defined later in this header.
21 class PatternRewriter;
22 
23 //===----------------------------------------------------------------------===//
24 // PatternBenefit class
25 //===----------------------------------------------------------------------===//
26 
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K. The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match.
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// A default-constructed benefit holds the "impossible to match" sentinel
  /// (see the default member initializer of `representation`).
  PatternBenefit() = default;
  /// Construct a benefit with the given value. Defined out of line;
  /// NOTE(review): presumably `benefit` must be below the sentinel — confirm.
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used for patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
  /// Return true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  // Comparisons order by the raw representation. Note that the sentinel is the
  // largest representable value, so it compares greater than every real
  // benefit; check isImpossibleToMatch() before ordering if that matters.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  unsigned short representation{ImpossibleToMatchSentinel};
};
64 
65 //===----------------------------------------------------------------------===//
66 // Pattern
67 //===----------------------------------------------------------------------===//
68 
69 /// This class contains all of the data related to a pattern, but does not
70 /// contain any methods or logic for the actual matching. This class is solely
71 /// used to interface with the metadata of a pattern, such as the benefit or
72 /// root operation.
73 class Pattern {
74  /// This enum represents the kind of value used to select the root operations
75  /// that match this pattern.
76  enum class RootKind {
77  /// The pattern root matches "any" operation.
78  Any,
79  /// The pattern root is matched using a concrete operation name.
81  /// The pattern root is matched using an interface ID.
82  InterfaceID,
83  /// The patter root is matched using a trait ID.
84  TraitID
85  };
86 
87 public:
88  /// Return a list of operations that may be generated when rewriting an
89  /// operation instance with this pattern.
90  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91 
92  /// Return the root node that this pattern matches. Patterns that can match
93  /// multiple root types return std::nullopt.
94  std::optional<OperationName> getRootKind() const {
95  if (rootKind == RootKind::OperationName)
96  return OperationName::getFromOpaquePointer(rootValue);
97  return std::nullopt;
98  }
99 
100  /// Return the interface ID used to match the root operation of this pattern.
101  /// If the pattern does not use an interface ID for deciding the root match,
102  /// this returns std::nullopt.
103  std::optional<TypeID> getRootInterfaceID() const {
104  if (rootKind == RootKind::InterfaceID)
105  return TypeID::getFromOpaquePointer(rootValue);
106  return std::nullopt;
107  }
108 
109  /// Return the trait ID used to match the root operation of this pattern.
110  /// If the pattern does not use a trait ID for deciding the root match, this
111  /// returns std::nullopt.
112  std::optional<TypeID> getRootTraitID() const {
113  if (rootKind == RootKind::TraitID)
114  return TypeID::getFromOpaquePointer(rootValue);
115  return std::nullopt;
116  }
117 
118  /// Return the benefit (the inverse of "cost") of matching this pattern. The
119  /// benefit of a Pattern is always static - rewrites that may have dynamic
120  /// benefit can be instantiated multiple times (different Pattern instances)
121  /// for each benefit that they may return, and be guarded by different match
122  /// condition predicates.
123  PatternBenefit getBenefit() const { return benefit; }
124 
125  /// Returns true if this pattern is known to result in recursive application,
126  /// i.e. this pattern may generate IR that also matches this pattern, but is
127  /// known to bound the recursion. This signals to a rewrite driver that it is
128  /// safe to apply this pattern recursively to generated IR.
130  return contextAndHasBoundedRecursion.getInt();
131  }
132 
133  /// Return the MLIRContext used to create this pattern.
135  return contextAndHasBoundedRecursion.getPointer();
136  }
137 
138  /// Return a readable name for this pattern. This name should only be used for
139  /// debugging purposes, and may be empty.
140  StringRef getDebugName() const { return debugName; }
141 
142  /// Set the human readable debug name used for this pattern. This name will
143  /// only be used for debugging purposes.
144  void setDebugName(StringRef name) { debugName = name; }
145 
146  /// Return the set of debug labels attached to this pattern.
147  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148 
149  /// Add the provided debug labels to this pattern.
151  debugLabels.append(labels.begin(), labels.end());
152  }
153  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
154 
155 protected:
156  /// This class acts as a special tag that makes the desire to match "any"
157  /// operation type explicit. This helps to avoid unnecessary usages of this
158  /// feature, and ensures that the user is making a conscious decision.
159  struct MatchAnyOpTypeTag {};
160  /// This class acts as a special tag that makes the desire to match any
161  /// operation that implements a given interface explicit. This helps to avoid
162  /// unnecessary usages of this feature, and ensures that the user is making a
163  /// conscious decision.
165  /// This class acts as a special tag that makes the desire to match any
166  /// operation that implements a given trait explicit. This helps to avoid
167  /// unnecessary usages of this feature, and ensures that the user is making a
168  /// conscious decision.
170 
171  /// Construct a pattern with a certain benefit that matches the operation
172  /// with the given root name.
173  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174  ArrayRef<StringRef> generatedNames = {});
175  /// Construct a pattern that may match any operation type. `generatedNames`
176  /// contains the names of operations that may be generated during a successful
177  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179  /// always be supplied here.
180  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181  ArrayRef<StringRef> generatedNames = {});
182  /// Construct a pattern that may match any operation that implements the
183  /// interface defined by the provided `interfaceID`. `generatedNames` contains
184  /// the names of operations that may be generated during a successful rewrite.
185  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186  /// interface" behavior is what the user actually desired,
187  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189  PatternBenefit benefit, MLIRContext *context,
190  ArrayRef<StringRef> generatedNames = {});
191  /// Construct a pattern that may match any operation that implements the
192  /// trait defined by the provided `traitID`. `generatedNames` contains the
193  /// names of operations that may be generated during a successful rewrite.
194  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196  /// always be supplied here.
197  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199 
200  /// Set the flag detailing if this pattern has bounded rewrite recursion or
201  /// not.
202  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204  }
205 
206 private:
207  Pattern(const void *rootValue, RootKind rootKind,
208  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209  MLIRContext *context);
210 
211  /// The value used to match the root operation of the pattern.
212  const void *rootValue;
213  RootKind rootKind;
214 
215  /// The expected benefit of matching this pattern.
216  const PatternBenefit benefit;
217 
218  /// The context this pattern was created from, and a boolean flag indicating
219  /// whether this pattern has bounded recursion or not.
220  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221 
222  /// A list of the potential operations that may be generated when rewriting
223  /// an op with this pattern.
224  SmallVector<OperationName, 2> generatedOps;
225 
226  /// A readable name for this pattern. May be empty.
227  StringRef debugName;
228 
229  /// The set of debug labels attached to this pattern.
230  SmallVector<StringRef, 0> debugLabels;
231 };
232 
233 //===----------------------------------------------------------------------===//
234 // RewritePattern
235 //===----------------------------------------------------------------------===//
236 
237 /// RewritePattern is the common base class for all DAG to DAG replacements.
238 class RewritePattern : public Pattern {
239 public:
240  virtual ~RewritePattern() = default;
241 
242  /// Attempt to match against code rooted at the specified operation,
243  /// which is the same operation code as getRootKind(). If successful, perform
244  /// the rewrite.
245  ///
246  /// Note: Implementations must modify the IR if and only if the function
247  /// returns "success".
248  virtual LogicalResult matchAndRewrite(Operation *op,
249  PatternRewriter &rewriter) const = 0;
250 
251  /// This method provides a convenient interface for creating and initializing
252  /// derived rewrite patterns of the given type `T`.
253  template <typename T, typename... Args>
254  static std::unique_ptr<T> create(Args &&...args) {
255  std::unique_ptr<T> pattern =
256  std::make_unique<T>(std::forward<Args>(args)...);
257  initializePattern<T>(*pattern);
258 
259  // Set a default debug name if one wasn't provided.
260  if (pattern->getDebugName().empty())
261  pattern->setDebugName(llvm::getTypeName<T>());
262  return pattern;
263  }
264 
265 protected:
266  /// Inherit the base constructors from `Pattern`.
267  using Pattern::Pattern;
268 
269 private:
270  /// Trait to check if T provides a `initialize` method.
271  template <typename T, typename... Args>
272  using has_initialize = decltype(std::declval<T>().initialize());
273  template <typename T>
274  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
275 
276  /// Initialize the derived pattern by calling its `initialize` method if
277  /// available.
278  template <typename T>
279  static void initializePattern(T &pattern) {
280  if constexpr (detect_has_initialize<T>::value)
281  pattern.initialize();
282  }
283 
284  /// An anchor for the virtual table.
285  virtual void anchor();
286 };
287 
288 namespace detail {
289 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
290 /// allows for matching and rewriting against an instance of a derived operation
291 /// class or Interface.
292 template <typename SourceOp>
294  using RewritePattern::RewritePattern;
295 
296  /// Wrapper around the RewritePattern method that passes the derived op type.
297  LogicalResult matchAndRewrite(Operation *op,
298  PatternRewriter &rewriter) const final {
299  return matchAndRewrite(cast<SourceOp>(op), rewriter);
300  }
301 
302  /// Method that operates on the SourceOp type. Must be overridden by the
303  /// derived pattern class.
304  virtual LogicalResult matchAndRewrite(SourceOp op,
305  PatternRewriter &rewriter) const = 0;
306 };
307 } // namespace detail
308 
309 /// OpRewritePattern is a wrapper around RewritePattern that allows for
310 /// matching and rewriting against an instance of a derived operation class as
311 /// opposed to a raw Operation.
312 template <typename SourceOp>
315 
316  /// Patterns must specify the root operation name they match against, and can
317  /// also specify the benefit of the pattern matching and a list of generated
318  /// ops.
320  ArrayRef<StringRef> generatedNames = {})
322  SourceOp::getOperationName(), benefit, context, generatedNames) {}
323 };
324 
325 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
326 /// matching and rewriting against an instance of an operation interface instead
327 /// of a raw Operation.
328 template <typename SourceOp>
331 
333  : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
334  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
335  benefit, context) {}
336 };
337 
338 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
339 /// matching and rewriting against instances of an operation that possess a
340 /// given trait.
341 template <template <typename> class TraitType>
343 public:
346  benefit, context) {}
347 };
348 
349 //===----------------------------------------------------------------------===//
350 // RewriterBase
351 //===----------------------------------------------------------------------===//
352 
353 /// This class coordinates the application of a rewrite on a set of IR,
354 /// providing a way for clients to track mutations and create new operations.
355 /// This class serves as a common API for IR mutation between pattern rewrites
356 /// and non-pattern rewrites, and facilitates the development of shared
357 /// IR transformation utilities.
358 class RewriterBase : public OpBuilder {
359 public:
360  struct Listener : public OpBuilder::Listener {
362  : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
363 
364  /// Notify the listener that the specified block is about to be erased.
365  /// At this point, the block has zero uses.
366  virtual void notifyBlockErased(Block *block) {}
367 
368  /// Notify the listener that the specified operation was modified in-place.
369  virtual void notifyOperationModified(Operation *op) {}
370 
371  /// Notify the listener that all uses of the specified operation's results
372  /// are about to be replaced with the results of another operation. This is
373  /// called before the uses of the old operation have been changed.
374  ///
375  /// By default, this function calls the "operation replaced with values"
376  /// notification.
378  Operation *replacement) {
379  notifyOperationReplaced(op, replacement->getResults());
380  }
381 
382  /// Notify the listener that all uses of the specified operation's results
383  /// are about to be replaced with the a range of values, potentially
384  /// produced by other operations. This is called before the uses of the
385  /// operation have been changed.
387  ValueRange replacement) {}
388 
389  /// Notify the listener that the specified operation is about to be erased.
390  /// At this point, the operation has zero uses.
391  ///
392  /// Note: This notification is not triggered when unlinking an operation.
393  virtual void notifyOperationErased(Operation *op) {}
394 
395  /// Notify the listener that the specified pattern is about to be applied
396  /// at the specified root operation.
397  virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
398 
399  /// Notify the listener that a pattern application finished with the
400  /// specified status. "success" indicates that the pattern was applied
401  /// successfully. "failure" indicates that the pattern could not be
402  /// applied. The pattern may have communicated the reason for the failure
403  /// with `notifyMatchFailure`.
404  virtual void notifyPatternEnd(const Pattern &pattern,
405  LogicalResult status) {}
406 
407  /// Notify the listener that the pattern failed to match, and provide a
408  /// callback to populate a diagnostic with the reason why the failure
409  /// occurred. This method allows for derived listeners to optionally hook
410  /// into the reason why a rewrite failed, and display it to users.
411  virtual void
413  function_ref<void(Diagnostic &)> reasonCallback) {}
414 
415  static bool classof(const OpBuilder::Listener *base);
416  };
417 
418  /// A listener that forwards all notifications to another listener. This
419  /// struct can be used as a base to create listener chains, so that multiple
420  /// listeners can be notified of IR changes.
423  : listener(listener),
424  rewriteListener(
425  dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
426 
427  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
428  if (listener)
429  listener->notifyOperationInserted(op, previous);
430  }
431  void notifyBlockInserted(Block *block, Region *previous,
432  Region::iterator previousIt) override {
433  if (listener)
434  listener->notifyBlockInserted(block, previous, previousIt);
435  }
436  void notifyBlockErased(Block *block) override {
437  if (rewriteListener)
438  rewriteListener->notifyBlockErased(block);
439  }
440  void notifyOperationModified(Operation *op) override {
441  if (rewriteListener)
442  rewriteListener->notifyOperationModified(op);
443  }
444  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
445  if (rewriteListener)
446  rewriteListener->notifyOperationReplaced(op, newOp);
447  }
449  ValueRange replacement) override {
450  if (rewriteListener)
451  rewriteListener->notifyOperationReplaced(op, replacement);
452  }
453  void notifyOperationErased(Operation *op) override {
454  if (rewriteListener)
455  rewriteListener->notifyOperationErased(op);
456  }
457  void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
458  if (rewriteListener)
459  rewriteListener->notifyPatternBegin(pattern, op);
460  }
461  void notifyPatternEnd(const Pattern &pattern,
462  LogicalResult status) override {
463  if (rewriteListener)
464  rewriteListener->notifyPatternEnd(pattern, status);
465  }
467  Location loc,
468  function_ref<void(Diagnostic &)> reasonCallback) override {
469  if (rewriteListener)
470  rewriteListener->notifyMatchFailure(loc, reasonCallback);
471  }
472 
473  private:
474  OpBuilder::Listener *listener;
475  RewriterBase::Listener *rewriteListener;
476  };
477 
478  /// A listener that logs notification events to llvm::dbgs() before
479  /// forwarding to the base listener.
481  PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
482  : RewriterBase::ForwardingListener(listener), patternName(patternName) {
483  }
484 
485  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
486  void notifyOperationModified(Operation *op) override;
487  void notifyOperationReplaced(Operation *op, Operation *newOp) override;
489  ValueRange replacement) override;
490  void notifyOperationErased(Operation *op) override;
491  void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
492 
493  private:
494  StringRef patternName;
495  };
496 
497  /// Move the blocks that belong to "region" before the given position in
498  /// another region "parent". The two regions must be different. The caller
499  /// is responsible for creating or updating the operation transferring flow
500  /// of control to the region and passing it the correct block arguments.
501  void inlineRegionBefore(Region &region, Region &parent,
502  Region::iterator before);
503  void inlineRegionBefore(Region &region, Block *before);
504 
505  /// Replace the results of the given (original) operation with the specified
506  /// list of values (replacements). The result types of the given op and the
507  /// replacements must match. The original op is erased.
508  virtual void replaceOp(Operation *op, ValueRange newValues);
509 
510  /// Replace the results of the given (original) operation with the specified
511  /// new op (replacement). The result types of the two ops must match. The
512  /// original op is erased.
513  virtual void replaceOp(Operation *op, Operation *newOp);
514 
515  /// Replace the results of the given (original) op with a new op that is
516  /// created without verification (replacement). The result values of the two
517  /// ops must match. The original op is erased.
518  template <typename OpTy, typename... Args>
519  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
520  auto builder = static_cast<OpBuilder *>(this);
521  auto newOp =
522  OpTy::create(*builder, op->getLoc(), std::forward<Args>(args)...);
523  replaceOp(op, newOp.getOperation());
524  return newOp;
525  }
526 
527  /// This method erases an operation that is known to have no uses.
528  ///
529  /// If the current insertion point is before the erased operation, it is
530  /// adjusted to the following operation (or the end of the block). If the
531  /// current insertion point is within the erased operation, the insertion
532  /// point is left in an invalid state.
533  virtual void eraseOp(Operation *op);
534 
535  /// This method erases all operations in a block.
536  virtual void eraseBlock(Block *block);
537 
538  /// Inline the operations of block 'source' into block 'dest' before the given
539  /// position. The source block will be deleted and must have no uses.
540  /// 'argValues' is used to replace the block arguments of 'source'.
541  ///
542  /// If the source block is inserted at the end of the dest block, the dest
543  /// block must have no successors. Similarly, if the source block is inserted
544  /// somewhere in the middle (or beginning) of the dest block, the source block
545  /// must have no successors. Otherwise, the resulting IR would have
546  /// unreachable operations.
547  ///
548  /// If the insertion point is within the source block, it is adjusted to the
549  /// destination block.
550  virtual void inlineBlockBefore(Block *source, Block *dest,
551  Block::iterator before,
552  ValueRange argValues = {});
553 
554  /// Inline the operations of block 'source' before the operation 'op'. The
555  /// source block will be deleted and must have no uses. 'argValues' is used to
556  /// replace the block arguments of 'source'
557  ///
558  /// The source block must have no successors. Otherwise, the resulting IR
559  /// would have unreachable operations.
560  ///
561  /// If the insertion point is within the source block, it is adjusted to the
562  /// destination block.
563  void inlineBlockBefore(Block *source, Operation *op,
564  ValueRange argValues = {});
565 
566  /// Inline the operations of block 'source' into the end of block 'dest'. The
567  /// source block will be deleted and must have no uses. 'argValues' is used to
568  /// replace the block arguments of 'source'
569  ///
570  /// The dest block must have no successors. Otherwise, the resulting IR would
571  /// have unreachable operation.
572  ///
573  /// If the insertion point is within the source block, it is adjusted to the
574  /// destination block.
575  void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {});
576 
577  /// Split the operations starting at "before" (inclusive) out of the given
578  /// block into a new block, and return it.
579  Block *splitBlock(Block *block, Block::iterator before);
580 
581  /// Unlink this operation from its current block and insert it right before
582  /// `existingOp` which may be in the same or another block in the same
583  /// function.
584  void moveOpBefore(Operation *op, Operation *existingOp);
585 
586  /// Unlink this operation from its current block and insert it right before
587  /// `iterator` in the specified block.
588  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
589 
590  /// Unlink this operation from its current block and insert it right after
591  /// `existingOp` which may be in the same or another block in the same
592  /// function.
593  void moveOpAfter(Operation *op, Operation *existingOp);
594 
595  /// Unlink this operation from its current block and insert it right after
596  /// `iterator` in the specified block.
597  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
598 
599  /// Unlink this block and insert it right before `existingBlock`.
600  void moveBlockBefore(Block *block, Block *anotherBlock);
601 
602  /// Unlink this block and insert it right before the location that the given
603  /// iterator points to in the given region.
604  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
605 
606  /// This method is used to notify the rewriter that an in-place operation
607  /// modification is about to happen. A call to this function *must* be
608  /// followed by a call to either `finalizeOpModification` or
609  /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
610  /// a new operation and removing the old one) but also often allows simpler
611  /// code in the client.
612  virtual void startOpModification(Operation *op) {}
613 
614  /// This method is used to signal the end of an in-place modification of the
615  /// given operation. This can only be called on operations that were provided
616  /// to a call to `startOpModification`.
617  virtual void finalizeOpModification(Operation *op);
618 
619  /// This method cancels a pending in-place modification. This can only be
620  /// called on operations that were provided to a call to
621  /// `startOpModification`.
622  virtual void cancelOpModification(Operation *op) {}
623 
624  /// This method is a utility wrapper around an in-place modification of an
625  /// operation. It wraps calls to `startOpModification` and
626  /// `finalizeOpModification` around the given callable.
627  template <typename CallableT>
628  void modifyOpInPlace(Operation *root, CallableT &&callable) {
629  startOpModification(root);
630  callable();
632  }
633 
634  /// Find uses of `from` and replace them with `to`. Also notify the listener
635  /// about every in-place op modification (for every use that was replaced).
636  virtual void replaceAllUsesWith(Value from, Value to) {
637  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
638  Operation *op = operand.getOwner();
639  modifyOpInPlace(op, [&]() { operand.set(to); });
640  }
641  }
642  void replaceAllUsesWith(Block *from, Block *to) {
643  for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
644  Operation *op = operand.getOwner();
645  modifyOpInPlace(op, [&]() { operand.set(to); });
646  }
647  }
649  assert(from.size() == to.size() && "incorrect number of replacements");
650  for (auto it : llvm::zip(from, to))
651  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
652  }
653 
654  /// Find uses of `from` and replace them with `to`. Also notify the listener
655  /// about every in-place op modification (for every use that was replaced)
656  /// and that the `from` operation is about to be replaced.
657  ///
658  /// Note: This function cannot be called `replaceAllUsesWith` because the
659  /// overload resolution, when called with an op that can be implicitly
660  /// converted to a Value, would be ambiguous.
661  void replaceAllOpUsesWith(Operation *from, ValueRange to);
662  void replaceAllOpUsesWith(Operation *from, Operation *to);
663 
664  /// Find uses of `from` and replace them with `to` if the `functor` returns
665  /// true. Also notify the listener about every in-place op modification (for
666  /// every use that was replaced). The optional `allUsesReplaced` flag is set
667  /// to "true" if all uses were replaced.
668  void replaceUsesWithIf(Value from, Value to,
669  function_ref<bool(OpOperand &)> functor,
670  bool *allUsesReplaced = nullptr);
672  function_ref<bool(OpOperand &)> functor,
673  bool *allUsesReplaced = nullptr);
674  // Note: This function cannot be called `replaceOpUsesWithIf` because the
675  // overload resolution, when called with an op that can be implicitly
676  // converted to a Value, would be ambiguous.
678  function_ref<bool(OpOperand &)> functor,
679  bool *allUsesReplaced = nullptr) {
680  replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
681  }
682 
683  /// Find uses of `from` within `block` and replace them with `to`. Also notify
684  /// the listener about every in-place op modification (for every use that was
685  /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
686  /// uses were replaced.
688  Block *block, bool *allUsesReplaced = nullptr) {
690  op, newValues,
691  [block](OpOperand &use) {
692  return block->getParentOp()->isProperAncestor(use.getOwner());
693  },
694  allUsesReplaced);
695  }
696 
697  /// Find uses of `from` and replace them with `to` except if the user is
698  /// `exceptedUser`. Also notify the listener about every in-place op
699  /// modification (for every use that was replaced).
700  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
701  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
702  Operation *user = use.getOwner();
703  return user != exceptedUser;
704  });
705  }
706  void replaceAllUsesExcept(Value from, Value to,
707  const SmallPtrSetImpl<Operation *> &preservedUsers);
708 
709  /// Used to notify the listener that the IR failed to be rewritten because of
710  /// a match failure, and provide a callback to populate a diagnostic with the
711  /// reason why the failure occurred. This method allows for derived rewriters
712  /// to optionally hook into the reason why a rewrite failed, and display it to
713  /// users.
714  template <typename CallbackT>
715  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
716  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
717  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
718  rewriteListener->notifyMatchFailure(
719  loc, function_ref<void(Diagnostic &)>(reasonCallback));
720  return failure();
721  }
722  template <typename CallbackT>
723  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
724  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
725  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
726  rewriteListener->notifyMatchFailure(
727  op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
728  return failure();
729  }
730  template <typename ArgT>
731  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
732  return notifyMatchFailure(std::forward<ArgT>(arg),
733  [&](Diagnostic &diag) { diag << msg; });
734  }
735  template <typename ArgT>
736  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
737  return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
738  }
739 
740 protected:
741  /// Initialize the builder.
742  explicit RewriterBase(MLIRContext *ctx,
743  OpBuilder::Listener *listener = nullptr)
744  : OpBuilder(ctx, listener) {}
745  explicit RewriterBase(const OpBuilder &otherBuilder)
746  : OpBuilder(otherBuilder) {}
748  : OpBuilder(op, listener) {}
749  virtual ~RewriterBase();
750 
751 private:
752  void operator=(const RewriterBase &) = delete;
753  RewriterBase(const RewriterBase &) = delete;
754 };
755 
756 //===----------------------------------------------------------------------===//
757 // IRRewriter
758 //===----------------------------------------------------------------------===//
759 
760 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
761 /// providing a way to keep track of the mutations made to the IR. This class
762 /// should only be used in situations where another `RewriterBase` instance,
763 /// such as a `PatternRewriter`, is not available.
764 class IRRewriter : public RewriterBase {
765 public:
767  : RewriterBase(ctx, listener) {}
768  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
770  : RewriterBase(op, listener) {}
771 };
772 
773 //===----------------------------------------------------------------------===//
774 // PatternRewriter
775 //===----------------------------------------------------------------------===//
776 
777 /// A special type of `RewriterBase` that coordinates the application of a
778 /// rewrite pattern on the current IR being matched, providing a way to keep
779 /// track of any mutations made. This class should be used to perform all
780 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
781 /// be tracking various state that would be invalidated when a mutation takes
782 /// place.
784 public:
785  explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
787 
788  /// A hook used to indicate if the pattern rewriter can recover from failure
789  /// during the rewrite stage of a pattern. For example, if the pattern
790  /// rewriter supports rollback, it may progress smoothly even if IR was
791  /// changed during the rewrite.
792  virtual bool canRecoverFromRewriteFailure() const { return false; }
793 };
794 
795 } // namespace mlir
796 
797 // Optionally expose PDL pattern matching methods.
798 #include "PDLPatternMatch.h.inc"
799 
800 namespace mlir {
801 
802 //===----------------------------------------------------------------------===//
803 // RewritePatternSet
804 //===----------------------------------------------------------------------===//
805 
807  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
808 
809 public:
810  RewritePatternSet(MLIRContext *context) : context(context) {}
811 
812  /// Construct a RewritePatternSet populated with the given pattern.
814  std::unique_ptr<RewritePattern> pattern)
815  : context(context) {
816  nativePatterns.emplace_back(std::move(pattern));
817  }
818  RewritePatternSet(PDLPatternModule &&pattern)
819  : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
820 
821  MLIRContext *getContext() const { return context; }
822 
823  /// Return the native patterns held in this list.
824  NativePatternListT &getNativePatterns() { return nativePatterns; }
825 
826  /// Return the PDL patterns held in this list.
827  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
828 
829  /// Clear out all of the held patterns in this list.
830  void clear() {
831  nativePatterns.clear();
832  pdlPatterns.clear();
833  }
834 
835  //===--------------------------------------------------------------------===//
836  // 'add' methods for adding patterns to the set.
837  //===--------------------------------------------------------------------===//
838 
839  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
840  /// the given arguments. Return a reference to `this` for chaining insertions.
841  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
842  template <typename... Ts, typename ConstructorArg,
843  typename... ConstructorArgs,
844  typename = std::enable_if_t<sizeof...(Ts) != 0>>
845  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
846  // The following expands a call to emplace_back for each of the pattern
847  // types 'Ts'.
848  (addImpl<Ts>(/*debugLabels=*/{}, std::forward<ConstructorArg>(arg),
849  std::forward<ConstructorArgs>(args)...),
850  ...);
851  return *this;
852  }
853  /// An overload of the above `add` method that allows for attaching a set
854  /// of debug labels to the attached patterns. This is useful for labeling
855  /// groups of patterns that may be shared between multiple different
856  /// passes/users.
857  template <typename... Ts, typename ConstructorArg,
858  typename... ConstructorArgs,
859  typename = std::enable_if_t<sizeof...(Ts) != 0>>
861  ConstructorArg &&arg,
862  ConstructorArgs &&...args) {
863  // The following expands a call to emplace_back for each of the pattern
864  // types 'Ts'.
865  (addImpl<Ts>(debugLabels, arg, args...), ...);
866  return *this;
867  }
868 
869  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
870  /// `this` for chaining insertions.
871  template <typename... Ts>
873  (addImpl<Ts>(), ...);
874  return *this;
875  }
876 
877  /// Add the given native pattern to the pattern list. Return a reference to
878  /// `this` for chaining insertions.
879  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
880  nativePatterns.emplace_back(std::move(pattern));
881  return *this;
882  }
883 
884  /// Add the given PDL pattern to the pattern list. Return a reference to
885  /// `this` for chaining insertions.
886  RewritePatternSet &add(PDLPatternModule &&pattern) {
887  pdlPatterns.mergeIn(std::move(pattern));
888  return *this;
889  }
890 
891  // Add a matchAndRewrite style pattern represented as a C function pointer.
892  template <typename OpType>
894  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
895  PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
896  struct FnPattern final : public OpRewritePattern<OpType> {
897  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
898  MLIRContext *context, PatternBenefit benefit,
899  ArrayRef<StringRef> generatedNames)
900  : OpRewritePattern<OpType>(context, benefit, generatedNames),
901  implFn(implFn) {}
902 
903  LogicalResult matchAndRewrite(OpType op,
904  PatternRewriter &rewriter) const override {
905  return implFn(op, rewriter);
906  }
907 
908  private:
909  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
910  };
911  add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
912  generatedNames));
913  return *this;
914  }
915 
916  //===--------------------------------------------------------------------===//
917  // Pattern Insertion
918  //===--------------------------------------------------------------------===//
919 
920  // TODO: These are soft deprecated in favor of the 'add' methods above.
921 
922  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
923  /// the given arguments. Return a reference to `this` for chaining insertions.
924  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
925  template <typename... Ts, typename ConstructorArg,
926  typename... ConstructorArgs,
927  typename = std::enable_if_t<sizeof...(Ts) != 0>>
928  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
929  // The following expands a call to emplace_back for each of the pattern
930  // types 'Ts'.
931  (addImpl<Ts>(/*debugLabels=*/{}, arg, args...), ...);
932  return *this;
933  }
934 
935  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
936  /// `this` for chaining insertions.
937  template <typename... Ts>
939  (addImpl<Ts>(), ...);
940  return *this;
941  }
942 
943  /// Add the given native pattern to the pattern list. Return a reference to
944  /// `this` for chaining insertions.
945  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
946  nativePatterns.emplace_back(std::move(pattern));
947  return *this;
948  }
949 
950  /// Add the given PDL pattern to the pattern list. Return a reference to
951  /// `this` for chaining insertions.
952  RewritePatternSet &insert(PDLPatternModule &&pattern) {
953  pdlPatterns.mergeIn(std::move(pattern));
954  return *this;
955  }
956 
957  // Add a matchAndRewrite style pattern represented as a C function pointer.
958  template <typename OpType>
960  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
961  struct FnPattern final : public OpRewritePattern<OpType> {
962  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
963  MLIRContext *context)
964  : OpRewritePattern<OpType>(context), implFn(implFn) {
965  this->setDebugName(llvm::getTypeName<FnPattern>());
966  }
967 
968  LogicalResult matchAndRewrite(OpType op,
969  PatternRewriter &rewriter) const override {
970  return implFn(op, rewriter);
971  }
972 
973  private:
974  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
975  };
976  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
977  return *this;
978  }
979 
980 private:
981  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
982  /// chaining insertions.
983  template <typename T, typename... Args>
984  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
985  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
986  std::unique_ptr<T> pattern =
987  RewritePattern::create<T>(std::forward<Args>(args)...);
988  pattern->addDebugLabels(debugLabels);
989  nativePatterns.emplace_back(std::move(pattern));
990  }
991 
992  template <typename T, typename... Args>
993  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
994  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
995  // TODO: Add the provided labels to the PDL pattern when PDL supports
996  // labels.
997  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
998  }
999 
1000  MLIRContext *const context;
1001  NativePatternListT nativePatterns;
1002 
1003  // Patterns expressed with PDL. This will compile to a stub class when PDL is
1004  // not enabled.
1005  PDLPatternModule pdlPatterns;
1006 };
1007 
1008 } // namespace mlir
1009 
1010 #endif // MLIR_IR_PATTERNMATCH_H
static std::string diag(const llvm::Value &value)
A block operand represents an operand that holds a reference to a Block, e.g.
Definition: BlockSupport.h:30
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: UseDefLists.h:253
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:766
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:768
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:769
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a saved insertion point.
Definition: Builders.h:327
This class helps build Operations.
Definition: Builders.h:207
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:610
This class represents an operand of an operation.
Definition: Value.h:257
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:342
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:344
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:54
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:50
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:59
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:58
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:53
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
PatternRewriter(MLIRContext *ctx)
Definition: PatternMatch.h:785
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:792
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:103
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:147
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:202
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:123
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:112
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:144
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:153
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:150
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:886
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Definition: PatternMatch.h:894
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
Definition: PatternMatch.h:824
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:879
RewritePatternSet(PDLPatternModule &&pattern)
Definition: PatternMatch.h:818
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:928
MLIRContext * getContext() const
Definition: PatternMatch.h:821
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:830
RewritePatternSet(MLIRContext *context)
Definition: PatternMatch.h:810
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:872
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
Definition: PatternMatch.h:813
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:952
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:945
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:827
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:938
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
Definition: PatternMatch.h:860
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
Definition: PatternMatch.h:960
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:254
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const =0
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Definition: PatternMatch.h:724
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Definition: PatternMatch.h:677
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Definition: PatternMatch.h:731
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:745
void replaceAllUsesWith(ValueRange from, ValueRange to)
Definition: PatternMatch.h:648
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink block and insert it right before anotherBlock.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:747
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:622
void replaceAllUsesWith(Block *from, Block *to)
Definition: PatternMatch.h:642
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:700
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Definition: PatternMatch.h:742
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
Definition: PatternMatch.h:736
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of the results of op within block and replace them with newValues.
Definition: PatternMatch.h:687
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:612
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:135
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Base class for listeners.
Definition: Builders.h:264
Kind
The kind of listener.
Definition: Builders.h:266
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:298
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:332
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:466
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:427
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:461
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:440
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:457
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:444
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:453
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:436
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:422
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:448
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
Definition: PatternMatch.h:431
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:412
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:369
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:393
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:377
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:366
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:404
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:397
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:386
static bool classof(const OpBuilder::Listener *base)
A listener that logs notification events to llvm::dbgs() before forwarding to the base listener.
Definition: PatternMatch.h:480
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
Definition: PatternMatch.h:481
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:293
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const =0
Method that operates on the SourceOp type.