MLIR  19.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 #include <optional>
17 
18 namespace mlir {
19 
20 class PatternRewriter;
21 
22 //===----------------------------------------------------------------------===//
23 // PatternBenefit class
24 //===----------------------------------------------------------------------===//
25 
26 /// This class represents the benefit of a pattern match in a unitless scheme
27 /// that ranges from 0 (very little benefit) to 65K. The most common unit to
28 /// use here is the "number of operations matched" by the pattern.
29 ///
30 /// This also has a sentinel representation that can be used for patterns that
31 /// fail to match.
32 ///
// NOTE(review): listing line 33 is missing. The members below use the name
// PatternBenefit for their constructors and sit around a `public:` label, so
// line 33 presumably held the `class PatternBenefit {` header -- confirm
// against the upstream header.
//
// Sentinel stored in `representation` for "impossible to match". 65535 is the
// largest value an `unsigned short` can hold.
34  enum { ImpossibleToMatchSentinel = 65535 };
35 
36 public:
// Defaulted: `representation` keeps its in-class initializer, i.e. a
// default-constructed PatternBenefit holds ImpossibleToMatchSentinel.
37  PatternBenefit() = default;
// Non-explicit converting constructor, defined out of line; how it treats
// values at or above the sentinel is not visible here.
38  PatternBenefit(unsigned benefit);
39  PatternBenefit(const PatternBenefit &) = default;
// NOTE(review): listing lines 40 and 42 are missing. Line 43 calls
// `impossibleToMatch()`, which is not declared anywhere in the visible text,
// so presumably one of the missing lines declared it (likely returning the
// sentinel-valued PatternBenefit) -- confirm upstream.
41 
43  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
44 
45  /// If the corresponding pattern can match, return its benefit. If the
46  /// corresponding pattern isImpossibleToMatch() then this aborts.
47  unsigned short getBenefit() const;
48 
// All comparisons are defined on the raw `representation`. Because the
// sentinel is the maximum `unsigned short`, an impossible-to-match benefit
// orders at or above every other benefit.
49  bool operator==(const PatternBenefit &rhs) const {
50  return representation == rhs.representation;
51  }
52  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
53  bool operator<(const PatternBenefit &rhs) const {
54  return representation < rhs.representation;
55  }
56  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
57  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
58  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
59 
60 private:
// The stored benefit value; defaults to the impossible-to-match sentinel.
61  unsigned short representation{ImpossibleToMatchSentinel};
62 };
63 
64 //===----------------------------------------------------------------------===//
65 // Pattern
66 //===----------------------------------------------------------------------===//
67 
68 /// This class contains all of the data related to a pattern, but does not
69 /// contain any methods or logic for the actual matching. This class is solely
70 /// used to interface with the metadata of a pattern, such as the benefit or
71 /// root operation.
72 class Pattern {
73  /// This enum represents the kind of value used to select the root operations
74  /// that match this pattern.
75  enum class RootKind {
76  /// The pattern root matches "any" operation.
77  Any,
78  /// The pattern root is matched using a concrete operation name.
// NOTE(review): listing line 79 is missing. getRootKind() below tests
// `RootKind::OperationName`, which is not otherwise declared here, so the
// missing line presumably declared that enumerator.
80  /// The pattern root is matched using an interface ID.
81  InterfaceID,
82  /// The pattern root is matched using a trait ID.
83  TraitID
84  };
85 
86 public:
87  /// Return a list of operations that may be generated when rewriting an
88  /// operation instance with this pattern.
89  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
90 
91  /// Return the root node that this pattern matches. Patterns that can match
92  /// multiple root types return std::nullopt.
// Only RootKind::OperationName yields a value; Any, InterfaceID and TraitID
// roots all return std::nullopt.
93  std::optional<OperationName> getRootKind() const {
94  if (rootKind == RootKind::OperationName)
95  return OperationName::getFromOpaquePointer(rootValue);
96  return std::nullopt;
97  }
98 
99  /// Return the interface ID used to match the root operation of this pattern.
100  /// If the pattern does not use an interface ID for deciding the root match,
101  /// this returns std::nullopt.
102  std::optional<TypeID> getRootInterfaceID() const {
103  if (rootKind == RootKind::InterfaceID)
104  return TypeID::getFromOpaquePointer(rootValue);
105  return std::nullopt;
106  }
107 
108  /// Return the trait ID used to match the root operation of this pattern.
109  /// If the pattern does not use a trait ID for deciding the root match, this
110  /// returns std::nullopt.
111  std::optional<TypeID> getRootTraitID() const {
112  if (rootKind == RootKind::TraitID)
113  return TypeID::getFromOpaquePointer(rootValue);
114  return std::nullopt;
115  }
116 
117  /// Return the benefit (the inverse of "cost") of matching this pattern. The
118  /// benefit of a Pattern is always static - rewrites that may have dynamic
119  /// benefit can be instantiated multiple times (different Pattern instances)
120  /// for each benefit that they may return, and be guarded by different match
121  /// condition predicates.
122  PatternBenefit getBenefit() const { return benefit; }
123 
124  /// Returns true if this pattern is known to result in recursive application,
125  /// i.e. this pattern may generate IR that also matches this pattern, but is
126  /// known to bound the recursion. This signals to a rewrite driver that it is
127  /// safe to apply this pattern recursively to generated IR.
// NOTE(review): listing line 128 (this accessor's signature) is missing;
// presumably `bool hasBoundedRewriteRecursion() const {`, the getter paired
// with setHasBoundedRewriteRecursion() below -- confirm upstream.
129  return contextAndHasBoundedRecursion.getInt();
130  }
131 
132  /// Return the MLIRContext used to create this pattern.
// NOTE(review): listing line 133 (this accessor's signature) is missing;
// presumably an `MLIRContext *getContext() const {` getter.
134  return contextAndHasBoundedRecursion.getPointer();
135  }
136 
137  /// Return a readable name for this pattern. This name should only be used for
138  /// debugging purposes, and may be empty.
139  StringRef getDebugName() const { return debugName; }
140 
141  /// Set the human readable debug name used for this pattern. This name will
142  /// only be used for debugging purposes.
// Only the StringRef is stored (no copy is made), so the caller must keep
// the underlying characters alive for as long as the pattern uses the name.
143  void setDebugName(StringRef name) { debugName = name; }
144 
145  /// Return the set of debug labels attached to this pattern.
146  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
147 
148  /// Add the provided debug labels to this pattern.
// NOTE(review): listing line 149 is missing. The body uses `labels.begin()` /
// `labels.end()`, so it was presumably an `addDebugLabels` overload taking a
// range of labels (e.g. ArrayRef<StringRef>). As with setDebugName, only the
// StringRefs are stored, not copies of the strings.
150  debugLabels.append(labels.begin(), labels.end());
151  }
152  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
153 
154 protected:
155  /// This class acts as a special tag that makes the desire to match "any"
156  /// operation type explicit. This helps to avoid unnecessary usages of this
157  /// feature, and ensures that the user is making a conscious decision.
158  struct MatchAnyOpTypeTag {};
159  /// This class acts as a special tag that makes the desire to match any
160  /// operation that implements a given interface explicit. This helps to avoid
161  /// unnecessary usages of this feature, and ensures that the user is making a
162  /// conscious decision.
// NOTE(review): listing line 163 is missing; presumably the
// `struct MatchInterfaceOpTypeTag {};` tag used by the constructor below.
164  /// This class acts as a special tag that makes the desire to match any
165  /// operation that implements a given trait explicit. This helps to avoid
166  /// unnecessary usages of this feature, and ensures that the user is making a
167  /// conscious decision.
// NOTE(review): listing line 168 is missing; presumably the
// `struct MatchTraitOpTypeTag {};` tag used by the constructor below.
169 
170  /// Construct a pattern with a certain benefit that matches the operation
171  /// with the given root name.
172  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
173  ArrayRef<StringRef> generatedNames = {});
174  /// Construct a pattern that may match any operation type. `generatedNames`
175  /// contains the names of operations that may be generated during a successful
176  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
177  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
178  /// always be supplied here.
179  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
180  ArrayRef<StringRef> generatedNames = {});
181  /// Construct a pattern that may match any operation that implements the
182  /// interface defined by the provided `interfaceID`. `generatedNames` contains
183  /// the names of operations that may be generated during a successful rewrite.
184  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
185  /// interface" behavior is what the user actually desired,
186  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
187  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
188  PatternBenefit benefit, MLIRContext *context,
189  ArrayRef<StringRef> generatedNames = {});
190  /// Construct a pattern that may match any operation that implements the
191  /// trait defined by the provided `traitID`. `generatedNames` contains the
192  /// names of operations that may be generated during a successful rewrite.
193  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
194  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
195  /// always be supplied here.
196  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
197  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
198 
199  /// Set the flag detailing if this pattern has bounded rewrite recursion or
200  /// not.
201  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
202  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
203  }
204 
205 private:
// Common constructor the public/protected constructors presumably delegate
// to (they are defined out of line).
206  Pattern(const void *rootValue, RootKind rootKind,
207  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
208  MLIRContext *context);
209 
210  /// The value used to match the root operation of the pattern.
// Decoded according to `rootKind` by the accessors above: as an OperationName
// opaque pointer for RootKind::OperationName, and as a TypeID opaque pointer
// for RootKind::InterfaceID / RootKind::TraitID. No visible accessor reads it
// for RootKind::Any.
211  const void *rootValue;
212  RootKind rootKind;
213 
214  /// The expected benefit of matching this pattern.
215  const PatternBenefit benefit;
216 
217  /// The context this pattern was created from, and a boolean flag indicating
218  /// whether this pattern has bounded recursion or not.
219  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
220 
221  /// A list of the potential operations that may be generated when rewriting
222  /// an op with this pattern.
223  SmallVector<OperationName, 2> generatedOps;
224 
225  /// A readable name for this pattern. May be empty.
226  StringRef debugName;
227 
228  /// The set of debug labels attached to this pattern.
229  SmallVector<StringRef, 0> debugLabels;
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // RewritePattern
234 //===----------------------------------------------------------------------===//
235 
236 /// RewritePattern is the common base class for all DAG to DAG replacements.
237 /// There are two possible usages of this class:
238 /// * Multi-step RewritePattern with "match" and "rewrite"
239 /// - By overriding the "match" and "rewrite" functions, the user can
240 /// separate the concerns of matching and rewriting.
241 /// * Single-step RewritePattern with "matchAndRewrite"
242 /// - By overriding the "matchAndRewrite" function, the user can perform
243 /// the rewrite in the same call as the match.
244 ///
245 class RewritePattern : public Pattern {
246 public:
247  virtual ~RewritePattern() = default;
248 
249  /// Rewrite the IR rooted at the specified operation with the result of
250  /// this pattern, generating any new operations with the specified
251  /// builder. If an unexpected error is encountered (an internal
252  /// compiler error), it is emitted through the normal MLIR diagnostic
253  /// hooks and the IR is left in a valid state.
254  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
255 
256  /// Attempt to match against code rooted at the specified operation,
257  /// which is the same operation code as getRootKind().
258  virtual LogicalResult match(Operation *op) const;
259 
260  /// Attempt to match against code rooted at the specified operation,
261  /// which is the same operation code as getRootKind(). If successful, this
262  /// function will automatically perform the rewrite.
// NOTE(review): listing line 263 is missing; it presumably opened this
// declaration as `virtual LogicalResult matchAndRewrite(Operation *op,`
// (the parameter list continues on line 264). The visible default body calls
// match() and invokes rewrite() only when the match succeeded.
264  PatternRewriter &rewriter) const {
265  if (succeeded(match(op))) {
266  rewrite(op, rewriter);
267  return success();
268  }
269  return failure();
270  }
271 
272  /// This method provides a convenient interface for creating and initializing
273  /// derived rewrite patterns of the given type `T`.
// Constructs a T from the forwarded arguments, runs T::initialize() when T
// declares one (see initializePattern below), and sets the debug name to
// llvm::getTypeName<T>() if the constructor did not set one.
274  template <typename T, typename... Args>
275  static std::unique_ptr<T> create(Args &&...args) {
276  std::unique_ptr<T> pattern =
277  std::make_unique<T>(std::forward<Args>(args)...);
278  initializePattern<T>(*pattern);
279 
280  // Set a default debug name if one wasn't provided.
281  if (pattern->getDebugName().empty())
282  pattern->setDebugName(llvm::getTypeName<T>());
283  return pattern;
284  }
285 
286 protected:
287  /// Inherit the base constructors from `Pattern`.
288  using Pattern::Pattern;
289 
290 private:
291  /// Trait to check if T provides an `initialize` method.
// (The `Args` pack on this alias is not used by its definition.)
292  template <typename T, typename... Args>
293  using has_initialize = decltype(std::declval<T>().initialize());
294  template <typename T>
295  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
296 
297  /// Initialize the derived pattern by calling its `initialize` method.
// The two initializePattern overloads have complementary enable_if
// conditions on detect_has_initialize<T>, so exactly one is viable for a
// given T.
298  template <typename T>
299  static std::enable_if_t<detect_has_initialize<T>::value>
300  initializePattern(T &pattern) {
301  pattern.initialize();
302  }
303  /// Empty derived pattern initializer for patterns that do not have an
304  /// initialize method.
305  template <typename T>
306  static std::enable_if_t<!detect_has_initialize<T>::value>
307  initializePattern(T &) {}
308 
309  /// An anchor for the virtual table.
310  virtual void anchor();
311 };
312 
313 namespace detail {
314 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
315 /// allows for matching and rewriting against an instance of a derived operation
316 /// class or Interface.
317 template <typename SourceOp>
// NOTE(review): listing line 318 is missing; it presumably opened the
// definition of OpOrInterfaceRewritePatternBase deriving from RewritePattern
// (the body inherits RewritePattern's constructors and marks its
// Operation * entry points `final`).
319  using RewritePattern::RewritePattern;
320 
321  /// Wrappers around the RewritePattern methods that pass the derived op type.
// Each wrapper converts with cast<SourceOp>, which asserts rather than
// checks that `op` is a SourceOp; presumably the driver only invokes the
// pattern on roots that satisfy the pattern's root kind.
322  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
323  rewrite(cast<SourceOp>(op), rewriter);
324  }
325  LogicalResult match(Operation *op) const final {
326  return match(cast<SourceOp>(op));
327  }
// NOTE(review): listing line 328 is missing; presumably the opening
// `LogicalResult matchAndRewrite(Operation *op,` of this third wrapper.
329  PatternRewriter &rewriter) const final {
330  return matchAndRewrite(cast<SourceOp>(op), rewriter);
331  }
332 
333  /// Rewrite and Match methods that operate on the SourceOp type. These must be
334  /// overridden by the derived pattern class.
// A derived pattern must override either matchAndRewrite(SourceOp, ...) or
// both match(SourceOp) and rewrite(SourceOp, ...). The default match/rewrite
// hit llvm_unreachable, and the default matchAndRewrite is composed from them.
335  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
336  llvm_unreachable("must override rewrite or matchAndRewrite");
337  }
338  virtual LogicalResult match(SourceOp op) const {
339  llvm_unreachable("must override match or matchAndRewrite");
340  }
341  virtual LogicalResult matchAndRewrite(SourceOp op,
342  PatternRewriter &rewriter) const {
343  if (succeeded(match(op))) {
344  rewrite(op, rewriter);
345  return success();
346  }
347  return failure();
348  }
349 };
350 } // namespace detail
351 
352 /// OpRewritePattern is a wrapper around RewritePattern that allows for
353 /// matching and rewriting against an instance of a derived operation class as
354 /// opposed to a raw Operation.
355 template <typename SourceOp>
// NOTE(review): listing line 356 is missing; it presumably named the
// OpRewritePattern type being defined (its base clause continues below).
357  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
358  /// Patterns must specify the root operation name they match against, and can
359  /// also specify the benefit of the pattern matching and a list of generated
360  /// ops.
// NOTE(review): listing lines 361 and 363 are missing. They presumably held
// the constructor's leading parameters (context, benefit) and the start of
// the base-class initializer. The visible text shows the root operation name
// is taken from SourceOp::getOperationName().
362  ArrayRef<StringRef> generatedNames = {})
364  SourceOp::getOperationName(), benefit, context, generatedNames) {}
365 };
366 
367 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
368 /// matching and rewriting against an instance of an operation interface instead
369 /// of a raw Operation.
370 template <typename SourceOp>
// NOTE(review): listing line 371 is missing; presumably the
// OpInterfaceRewritePattern definition header (base clause continues below).
372  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
// NOTE(review): listing line 373 is missing; presumably the constructor
// signature taking (context, benefit). The initializer below roots the
// pattern on SourceOp::getInterfaceID() via MatchInterfaceOpTypeTag and
// passes no generated-op names.
374  : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
375  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
376  benefit, context) {}
377 };
378 
379 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
380 /// matching and rewriting against instances of an operation that possess a
381 /// given trait.
382 template <template <typename> class TraitType>
// NOTE(review): listing line 383 is missing; presumably the
// OpTraitRewritePattern definition header. The visible text does not show
// its base class; the doc comment suggests RewritePattern -- confirm.
384 public:
// NOTE(review): listing lines 385-386 are missing. They presumably held the
// constructor signature (context, benefit) and the start of its base
// initializer, likely passing Pattern::MatchTraitOpTypeTag and the TypeID of
// TraitType -- confirm upstream.
387  benefit, context) {}
388 };
389 
390 //===----------------------------------------------------------------------===//
391 // RewriterBase
392 //===----------------------------------------------------------------------===//
393 
394 /// This class coordinates the application of a rewrite on a set of IR,
395 /// providing a way for clients to track mutations and create new operations.
396 /// This class serves as a common API for IR mutation between pattern rewrites
397 /// and non-pattern rewrites, and facilitates the development of shared
398 /// IR transformation utilities.
399 class RewriterBase : public OpBuilder {
400 public:
401  struct Listener : public OpBuilder::Listener {
// NOTE(review): listing line 402 is missing; presumably the Listener default
// constructor. Its initializer (below) tags the listener with
// ListenerBase::Kind::RewriterBaseListener, presumably what classof() checks
// so that dyn_cast<RewriterBase::Listener> works.
403  : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
404 
405  /// Notify the listener that the specified block is about to be erased.
406  /// At this point, the block has zero uses.
407  virtual void notifyBlockErased(Block *block) {}
408 
409  /// Notify the listener that the specified operation was modified in-place.
410  virtual void notifyOperationModified(Operation *op) {}
411 
412  /// Notify the listener that the specified operation is about to be replaced
413  /// with another operation. This is called before the uses of the old
414  /// operation have been changed.
415  ///
416  /// By default, this function calls the "operation replaced with values"
417  /// notification.
// NOTE(review): listing line 418 is missing; presumably
// `virtual void notifyOperationReplaced(Operation *op,` (the overload whose
// replacement is an Operation *).
419  Operation *replacement) {
420  notifyOperationReplaced(op, replacement->getResults());
421  }
422 
423  /// Notify the listener that the specified operation is about to be replaced
424  /// with a range of values, potentially produced by other operations.
425  /// This is called before the uses of the operation have been changed.
// NOTE(review): listing line 426 is missing; presumably the start of the
// `notifyOperationReplaced(Operation *op, ValueRange replacement)` overload.
427  ValueRange replacement) {}
428 
429  /// Notify the listener that the specified operation is about to be erased.
430  /// At this point, the operation has zero uses.
431  ///
432  /// Note: This notification is not triggered when unlinking an operation.
433  virtual void notifyOperationErased(Operation *op) {}
434 
435  /// Notify the listener that the specified pattern is about to be applied
436  /// at the specified root operation.
437  virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
438 
439  /// Notify the listener that a pattern application finished with the
440  /// specified status. "success" indicates that the pattern was applied
441  /// successfully. "failure" indicates that the pattern could not be
442  /// applied. The pattern may have communicated the reason for the failure
443  /// with `notifyMatchFailure`.
444  virtual void notifyPatternEnd(const Pattern &pattern,
445  LogicalResult status) {}
446 
447  /// Notify the listener that the pattern failed to match, and provide a
448  /// callback to populate a diagnostic with the reason why the failure
449  /// occurred. This method allows for derived listeners to optionally hook
450  /// into the reason why a rewrite failed, and display it to users.
451  virtual void
// NOTE(review): listing line 452 is missing; presumably
// `notifyMatchFailure(Location loc,` -- ForwardingListener below overrides
// notifyMatchFailure(loc, reasonCallback).
453  function_ref<void(Diagnostic &)> reasonCallback) {}
454 
455  static bool classof(const OpBuilder::Listener *base);
456  };
457 
458  /// A listener that forwards all notifications to another listener. This
459  /// struct can be used as a base to create listener chains, so that multiple
460  /// listeners can be notified of IR changes.
// NOTE(review): listing line 461 is missing; presumably the
// ForwardingListener definition header (its methods are `override`s, so it
// derives from a listener type -- presumably RewriterBase::Listener).
//
// Forwarding behavior visible below:
// - The OpBuilder-level notifications (operation/block inserted) are always
//   forwarded.
// - The RewriterBase-specific notifications are forwarded only when the
//   wrapped listener is a RewriterBase::Listener (checked with dyn_cast);
//   otherwise they are silently dropped.
// - The wrapped listener is not owned. It is dereferenced / dyn_cast without
//   a null check, so it must not be null.
462  ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
463 
464  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
465  listener->notifyOperationInserted(op, previous);
466  }
467  void notifyBlockInserted(Block *block, Region *previous,
468  Region::iterator previousIt) override {
469  listener->notifyBlockInserted(block, previous, previousIt);
470  }
471  void notifyBlockErased(Block *block) override {
472  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
473  rewriteListener->notifyBlockErased(block);
474  }
475  void notifyOperationModified(Operation *op) override {
476  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
477  rewriteListener->notifyOperationModified(op);
478  }
479  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
480  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
481  rewriteListener->notifyOperationReplaced(op, newOp);
482  }
// NOTE(review): listing line 483 is missing; presumably
// `void notifyOperationReplaced(Operation *op,` (the ValueRange overload).
484  ValueRange replacement) override {
485  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
486  rewriteListener->notifyOperationReplaced(op, replacement);
487  }
488  void notifyOperationErased(Operation *op) override {
489  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
490  rewriteListener->notifyOperationErased(op);
491  }
492  void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
493  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
494  rewriteListener->notifyPatternBegin(pattern, op);
495  }
496  void notifyPatternEnd(const Pattern &pattern,
497  LogicalResult status) override {
498  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
499  rewriteListener->notifyPatternEnd(pattern, status);
500  }
// NOTE(review): listing line 501 is missing; presumably
// `void notifyMatchFailure(`.
502  Location loc,
503  function_ref<void(Diagnostic &)> reasonCallback) override {
504  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
505  rewriteListener->notifyMatchFailure(loc, reasonCallback);
506  }
507 
508  private:
509  OpBuilder::Listener *listener;
510  };
511 
512  /// Move the blocks that belong to "region" before the given position in
513  /// another region "parent". The two regions must be different. The caller
514  /// is responsible for creating or updating the operation transferring flow
515  /// of control to the region and passing it the correct block arguments.
516  void inlineRegionBefore(Region &region, Region &parent,
517  Region::iterator before);
518  void inlineRegionBefore(Region &region, Block *before);
519 
520  /// Replace the results of the given (original) operation with the specified
521  /// list of values (replacements). The result types of the given op and the
522  /// replacements must match. The original op is erased.
523  virtual void replaceOp(Operation *op, ValueRange newValues);
524 
525  /// Replace the results of the given (original) operation with the specified
526  /// new op (replacement). The result types of the two ops must match. The
527  /// original op is erased.
528  virtual void replaceOp(Operation *op, Operation *newOp);
529 
530  /// Replace the results of the given (original) op with a new op that is
531  /// created without verification (replacement). The result values of the two
532  /// ops must match. The original op is erased.
// Creates an OpTy at `op`'s location from the forwarded arguments, then
// delegates to the virtual replaceOp(Operation *, Operation *) and returns
// the new op.
533  template <typename OpTy, typename... Args>
534  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
535  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
536  replaceOp(op, newOp.getOperation());
537  return newOp;
538  }
539 
540  /// This method erases an operation that is known to have no uses.
541  virtual void eraseOp(Operation *op);
542 
543  /// This method erases all operations in a block.
544  virtual void eraseBlock(Block *block);
545 
546  /// Inline the operations of block 'source' into block 'dest' before the given
547  /// position. The source block will be deleted and must have no uses.
548  /// 'argValues' is used to replace the block arguments of 'source'.
549  ///
550  /// If the source block is inserted at the end of the dest block, the dest
551  /// block must have no successors. Similarly, if the source block is inserted
552  /// somewhere in the middle (or beginning) of the dest block, the source block
553  /// must have no successors. Otherwise, the resulting IR would have
554  /// unreachable operations.
555  virtual void inlineBlockBefore(Block *source, Block *dest,
556  Block::iterator before,
557  ValueRange argValues = std::nullopt);
558 
559  /// Inline the operations of block 'source' before the operation 'op'. The
560  /// source block will be deleted and must have no uses. 'argValues' is used to
561  /// replace the block arguments of 'source'
562  ///
563  /// The source block must have no successors. Otherwise, the resulting IR
564  /// would have unreachable operations.
565  void inlineBlockBefore(Block *source, Operation *op,
566  ValueRange argValues = std::nullopt);
567 
568  /// Inline the operations of block 'source' into the end of block 'dest'. The
569  /// source block will be deleted and must have no uses. 'argValues' is used to
570  /// replace the block arguments of 'source'
571  ///
572  /// The dest block must have no successors. Otherwise, the resulting IR would
573  /// have unreachable operations.
574  void mergeBlocks(Block *source, Block *dest,
575  ValueRange argValues = std::nullopt);
576 
577  /// Split the operations starting at "before" (inclusive) out of the given
578  /// block into a new block, and return it.
579  Block *splitBlock(Block *block, Block::iterator before);
580 
581  /// Unlink this operation from its current block and insert it right before
582  /// `existingOp` which may be in the same or another block in the same
583  /// function.
584  void moveOpBefore(Operation *op, Operation *existingOp);
585 
586  /// Unlink this operation from its current block and insert it right before
587  /// `iterator` in the specified block.
588  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
589 
590  /// Unlink this operation from its current block and insert it right after
591  /// `existingOp` which may be in the same or another block in the same
592  /// function.
593  void moveOpAfter(Operation *op, Operation *existingOp);
594 
595  /// Unlink this operation from its current block and insert it right after
596  /// `iterator` in the specified block.
597  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
598 
599  /// Unlink this block and insert it right before `anotherBlock`.
600  void moveBlockBefore(Block *block, Block *anotherBlock);
601 
602  /// Unlink this block and insert it right before the location that the given
603  /// iterator points to in the given region.
604  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
605 
606  /// This method is used to notify the rewriter that an in-place operation
607  /// modification is about to happen. A call to this function *must* be
608  /// followed by a call to either `finalizeOpModification` or
609  /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
610  /// a new operation and removing the old one) but also often allows simpler
611  /// code in the client.
612  virtual void startOpModification(Operation *op) {}
613 
614  /// This method is used to signal the end of an in-place modification of the
615  /// given operation. This can only be called on operations that were provided
616  /// to a call to `startOpModification`.
617  virtual void finalizeOpModification(Operation *op);
618 
619  /// This method cancels a pending in-place modification. This can only be
620  /// called on operations that were provided to a call to
621  /// `startOpModification`.
622  virtual void cancelOpModification(Operation *op) {}
623 
624  /// This method is a utility wrapper around an in-place modification of an
625  /// operation. It wraps calls to `startOpModification` and
626  /// `finalizeOpModification` around the given callable.
627  template <typename CallableT>
628  void modifyOpInPlace(Operation *root, CallableT &&callable) {
629  startOpModification(root);
630  callable();
// NOTE(review): listing line 631 is missing. Per the doc comment above it
// presumably calls `finalizeOpModification(root);` -- without it, the
// start/finalize pairing required by startOpModification would be broken.
632  }
633 
634  /// Find uses of `from` and replace them with `to`. Also notify the listener
635  /// about every in-place op modification (for every use that was replaced).
636  void replaceAllUsesWith(Value from, Value to) {
637  return replaceAllUsesWith(from.getImpl(), to);
638  }
639  template <typename OperandType, typename ValueT>
// NOTE(review): listing line 640 is missing; presumably the signature of
// this templated replaceAllUsesWith overload, taking `from` as a use-list
// object (the body calls from->getUses()) and `to` as a ValueT.
// The uses are walked with make_early_inc_range, presumably because
// operand.set(to) unlinks the current use from `from`'s use list while the
// loop is iterating it. Each update goes through modifyOpInPlace so the
// listener sees the modification.
641  for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
642  Operation *op = operand.getOwner();
643  modifyOpInPlace(op, [&]() { operand.set(to); });
644  }
645  }
// NOTE(review): listing line 646 is missing; presumably the signature of a
// range overload `replaceAllUsesWith(ValueRange from, ValueRange to)`. The
// body asserts equal sizes and replaces `from` and `to` pairwise.
647  assert(from.size() == to.size() && "incorrect number of replacements");
648  for (auto it : llvm::zip(from, to))
649  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
650  }
651  // Note: This function cannot be called `replaceAllUsesWith` because the
652  // overload resolution, when called with an op that can be implicitly
653  // converted to a Value, would be ambiguous.
// NOTE(review): listing line 654 is missing; presumably the signature of the
// Operation *-based wrapper (upstream: `replaceAllOpUsesWith(Operation *from,
// ValueRange to)`); the body replaces all uses of from's results.
655  replaceAllUsesWith(from->getResults(), to);
656  }
657 
658  /// Find uses of `from` and replace them with `to` if the `functor` returns
659  /// true. Also notify the listener about every in-place op modification (for
660  /// every use that was replaced). The optional `allUsesReplaced` flag is set
661  /// to "true" if all uses were replaced.
662  void replaceUsesWithIf(Value from, Value to,
663  function_ref<bool(OpOperand &)> functor,
664  bool *allUsesReplaced = nullptr);
// NOTE(review): listing line 665 is missing; presumably the start of a
// `replaceUsesWithIf(ValueRange from, ValueRange to,` overload (line 674
// calls replaceUsesWithIf with a result range).
666  function_ref<bool(OpOperand &)> functor,
667  bool *allUsesReplaced = nullptr);
668  // Note: This function cannot be called `replaceOpUsesWithIf` because the
669  // overload resolution, when called with an op that can be implicitly
670  // converted to a Value, would be ambiguous.
// NOTE(review): the note above appears self-referential. By analogy with the
// replaceAllUsesWith note earlier, the name it means is probably the
// overloaded `replaceUsesWithIf` (lines 662 and 674) -- confirm against the
// missing line 671.
// NOTE(review): listing line 671 is missing; presumably this wrapper's
// signature, taking `from` as an Operation * (the body uses
// from->getResults()).
672  function_ref<bool(OpOperand &)> functor,
673  bool *allUsesReplaced = nullptr) {
674  replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
675  }
676 
677  /// Find uses of the results of `op` within `block` and replace them with
678  /// `newValues`. Also notify the listener about every in-place op
679  /// modification (for every use that was replaced). The optional
680  /// `allUsesReplaced` flag is set to "true" if all uses were replaced.
// NOTE(review): listing lines 681 and 683 are missing. They presumably held
// the first signature line (taking `op` and `newValues`) and the call to the
// Operation *-based replace-if wrapper.
// NOTE(review): the predicate below selects uses whose owner is properly
// nested inside `block->getParentOp()`. That also matches uses in other
// blocks of the same parent op (e.g. sibling blocks of `block`), not only
// uses inside `block` itself. It also dereferences getParentOp() without a
// null check, so `block` presumably must be attached to a parent op. If
// strict "within `block`" semantics are intended, a test such as
// `block->findAncestorOpInBlock(*use.getOwner())` would be tighter -- confirm
// the intended behavior before changing.
682  Block *block, bool *allUsesReplaced = nullptr) {
684  op, newValues,
685  [block](OpOperand &use) {
686  return block->getParentOp()->isProperAncestor(use.getOwner());
687  },
688  allUsesReplaced);
689  }
690 
691  /// Find uses of `from` and replace them with `to` except if the user is
692  /// `exceptedUser`. Also notify the listener about every in-place op
693  /// modification (for every use that was replaced).
694  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
695  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
696  Operation *user = use.getOwner();
697  return user != exceptedUser;
698  });
699  }
700 
701  /// Used to notify the listener that the IR failed to be rewritten because of
702  /// a match failure, and provide a callback to populate a diagnostic with the
703  /// reason why the failure occurred. This method allows for derived rewriters
704  /// to optionally hook into the reason why a rewrite failed, and display it to
705  /// users.
// Overload set:
// - The callback overloads are disabled (enable_if) for arguments convertible
//   to Twine. Message strings therefore route to the Twine / const char *
//   overloads below, which wrap the message in a callback that streams it
//   into the Diagnostic.
// - The callback is invoked only if a listener is set and it is a
//   RewriterBase::Listener (dyn_cast_if_present).
// - Every overload returns failure(), so callers can return the result
//   directly.
706  template <typename CallbackT>
707  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
708  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
709  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
710  rewriteListener->notifyMatchFailure(
711  loc, function_ref<void(Diagnostic &)>(reasonCallback));
712  return failure();
713  }
714  template <typename CallbackT>
715  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
716  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
717  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
718  rewriteListener->notifyMatchFailure(
719  op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
720  return failure();
721  }
722  template <typename ArgT>
723  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
724  return notifyMatchFailure(std::forward<ArgT>(arg),
725  [&](Diagnostic &diag) { diag << msg; });
726  }
727  template <typename ArgT>
728  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
729  return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
730  }
731 
732 protected:
733  /// Initialize the builder.
734  explicit RewriterBase(MLIRContext *ctx,
735  OpBuilder::Listener *listener = nullptr)
736  : OpBuilder(ctx, listener) {}
737  explicit RewriterBase(const OpBuilder &otherBuilder)
738  : OpBuilder(otherBuilder) {}
// NOTE(review): listing line 739 is missing; presumably a constructor taking
// an Operation * and a listener (the initializer forwards `op` and
// `listener` to OpBuilder).
740  : OpBuilder(op, listener) {}
741  virtual ~RewriterBase();
742 
743 private:
// Copying a RewriterBase is disabled: copy assignment and copy construction
// are both deleted.
744  void operator=(const RewriterBase &) = delete;
745  RewriterBase(const RewriterBase &) = delete;
746 };
747 
748 //===----------------------------------------------------------------------===//
749 // IRRewriter
750 //===----------------------------------------------------------------------===//
751 
752 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
753 /// providing a way to keep track of the mutations made to the IR. This class
754 /// should only be used in situations where another `RewriterBase` instance,
755 /// such as a `PatternRewriter`, is not available.
756 class IRRewriter : public RewriterBase {
757 public:
759  : RewriterBase(ctx, listener) {}
760  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
762  : RewriterBase(op, listener) {}
763 };
764 
765 //===----------------------------------------------------------------------===//
766 // PatternRewriter
767 //===----------------------------------------------------------------------===//
768 
769 /// A special type of `RewriterBase` that coordinates the application of a
770 /// rewrite pattern on the current IR being matched, providing a way to keep
771 /// track of any mutations made. This class should be used to perform all
772 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
773 /// be tracking various state that would be invalidated when a mutation takes
774 /// place.
776 public:
778 
779  /// A hook used to indicate if the pattern rewriter can recover from failure
780  /// during the rewrite stage of a pattern. For example, if the pattern
781  /// rewriter supports rollback, it may progress smoothly even if IR was
782  /// changed during the rewrite.
783  virtual bool canRecoverFromRewriteFailure() const { return false; }
784 };
785 
786 } // namespace mlir
787 
788 // Optionally expose PDL pattern matching methods.
789 #include "PDLPatternMatch.h.inc"
790 
791 namespace mlir {
792 
793 //===----------------------------------------------------------------------===//
794 // RewritePatternSet
795 //===----------------------------------------------------------------------===//
796 
798  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
799 
800 public:
801  RewritePatternSet(MLIRContext *context) : context(context) {}
802 
803  /// Construct a RewritePatternSet populated with the given pattern.
805  std::unique_ptr<RewritePattern> pattern)
806  : context(context) {
807  nativePatterns.emplace_back(std::move(pattern));
808  }
809  RewritePatternSet(PDLPatternModule &&pattern)
810  : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
811 
812  MLIRContext *getContext() const { return context; }
813 
814  /// Return the native patterns held in this list.
815  NativePatternListT &getNativePatterns() { return nativePatterns; }
816 
817  /// Return the PDL patterns held in this list.
818  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
819 
820  /// Clear out all of the held patterns in this list.
821  void clear() {
822  nativePatterns.clear();
823  pdlPatterns.clear();
824  }
825 
826  //===--------------------------------------------------------------------===//
827  // 'add' methods for adding patterns to the set.
828  //===--------------------------------------------------------------------===//
829 
830  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
831  /// the given arguments. Return a reference to `this` for chaining insertions.
832  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
833  template <typename... Ts, typename ConstructorArg,
834  typename... ConstructorArgs,
835  typename = std::enable_if_t<sizeof...(Ts) != 0>>
836  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
837  // The following expands a call to emplace_back for each of the pattern
838  // types 'Ts'.
839  (addImpl<Ts>(/*debugLabels=*/std::nullopt,
840  std::forward<ConstructorArg>(arg),
841  std::forward<ConstructorArgs>(args)...),
842  ...);
843  return *this;
844  }
845  /// An overload of the above `add` method that allows for attaching a set
846  /// of debug labels to the attached patterns. This is useful for labeling
847  /// groups of patterns that may be shared between multiple different
848  /// passes/users.
849  template <typename... Ts, typename ConstructorArg,
850  typename... ConstructorArgs,
851  typename = std::enable_if_t<sizeof...(Ts) != 0>>
853  ConstructorArg &&arg,
854  ConstructorArgs &&...args) {
855  // The following expands a call to emplace_back for each of the pattern
856  // types 'Ts'.
857  (addImpl<Ts>(debugLabels, arg, args...), ...);
858  return *this;
859  }
860 
861  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
862  /// `this` for chaining insertions.
863  template <typename... Ts>
865  (addImpl<Ts>(), ...);
866  return *this;
867  }
868 
869  /// Add the given native pattern to the pattern list. Return a reference to
870  /// `this` for chaining insertions.
871  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
872  nativePatterns.emplace_back(std::move(pattern));
873  return *this;
874  }
875 
876  /// Add the given PDL pattern to the pattern list. Return a reference to
877  /// `this` for chaining insertions.
878  RewritePatternSet &add(PDLPatternModule &&pattern) {
879  pdlPatterns.mergeIn(std::move(pattern));
880  return *this;
881  }
882 
883  // Add a matchAndRewrite style pattern represented as a C function pointer.
884  template <typename OpType>
886  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
887  PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
888  struct FnPattern final : public OpRewritePattern<OpType> {
889  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
890  MLIRContext *context, PatternBenefit benefit,
891  ArrayRef<StringRef> generatedNames)
892  : OpRewritePattern<OpType>(context, benefit, generatedNames),
893  implFn(implFn) {}
894 
895  LogicalResult matchAndRewrite(OpType op,
896  PatternRewriter &rewriter) const override {
897  return implFn(op, rewriter);
898  }
899 
900  private:
901  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
902  };
903  add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
904  generatedNames));
905  return *this;
906  }
907 
908  //===--------------------------------------------------------------------===//
909  // Pattern Insertion
910  //===--------------------------------------------------------------------===//
911 
912  // TODO: These are soft deprecated in favor of the 'add' methods above.
913 
914  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
915  /// the given arguments. Return a reference to `this` for chaining insertions.
916  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
917  template <typename... Ts, typename ConstructorArg,
918  typename... ConstructorArgs,
919  typename = std::enable_if_t<sizeof...(Ts) != 0>>
920  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
921  // The following expands a call to emplace_back for each of the pattern
922  // types 'Ts'.
923  (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...);
924  return *this;
925  }
926 
927  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
928  /// `this` for chaining insertions.
929  template <typename... Ts>
931  (addImpl<Ts>(), ...);
932  return *this;
933  }
934 
935  /// Add the given native pattern to the pattern list. Return a reference to
936  /// `this` for chaining insertions.
937  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
938  nativePatterns.emplace_back(std::move(pattern));
939  return *this;
940  }
941 
942  /// Add the given PDL pattern to the pattern list. Return a reference to
943  /// `this` for chaining insertions.
944  RewritePatternSet &insert(PDLPatternModule &&pattern) {
945  pdlPatterns.mergeIn(std::move(pattern));
946  return *this;
947  }
948 
949  // Add a matchAndRewrite style pattern represented as a C function pointer.
950  template <typename OpType>
952  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
953  struct FnPattern final : public OpRewritePattern<OpType> {
954  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
955  MLIRContext *context)
956  : OpRewritePattern<OpType>(context), implFn(implFn) {
957  this->setDebugName(llvm::getTypeName<FnPattern>());
958  }
959 
960  LogicalResult matchAndRewrite(OpType op,
961  PatternRewriter &rewriter) const override {
962  return implFn(op, rewriter);
963  }
964 
965  private:
966  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
967  };
968  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
969  return *this;
970  }
971 
972 private:
973  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
974  /// chaining insertions.
975  template <typename T, typename... Args>
976  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
977  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
978  std::unique_ptr<T> pattern =
979  RewritePattern::create<T>(std::forward<Args>(args)...);
980  pattern->addDebugLabels(debugLabels);
981  nativePatterns.emplace_back(std::move(pattern));
982  }
983 
984  template <typename T, typename... Args>
985  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
986  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
987  // TODO: Add the provided labels to the PDL pattern when PDL supports
988  // labels.
989  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
990  }
991 
992  MLIRContext *const context;
993  NativePatternListT nativePatterns;
994 
995  // Patterns expressed with PDL. This will compile to a stub class when PDL is
996  // not enabled.
997  PDLPatternModule pdlPatterns;
998 };
999 
1000 } // namespace mlir
1001 
1002 #endif // MLIR_IR_PATTERNMATCH_H
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class represents a single IR object that contains a use list.
Definition: UseDefLists.h:195
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: UseDefLists.h:253
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:756
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:758
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:760
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:761
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:209
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:601
This class represents an operand of an operation.
Definition: Value.h:263
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:383
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:385
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:53
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:49
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:42
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:58
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:52
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:56
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:783
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:102
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:128
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:93
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:133
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:146
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:201
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:89
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:122
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:111
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:143
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:152
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:149
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:139
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:878
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Definition: PatternMatch.h:886
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
Definition: PatternMatch.h:815
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:871
RewritePatternSet(PDLPatternModule &&pattern)
Definition: PatternMatch.h:809
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:920
MLIRContext * getContext() const
Definition: PatternMatch.h:812
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:821
RewritePatternSet(MLIRContext *context)
Definition: PatternMatch.h:801
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:864
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
Definition: PatternMatch.h:804
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:944
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:937
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:818
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:930
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
Definition: PatternMatch.h:852
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
Definition: PatternMatch.h:952
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:275
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:263
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Definition: PatternMatch.h:716
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Definition: PatternMatch.h:671
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Definition: PatternMatch.h:723
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:737
void replaceAllUsesWith(ValueRange from, ValueRange to)
Definition: PatternMatch.h:646
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink `block` and insert it right before `anotherBlock`.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:739
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:622
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:694
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Definition: PatternMatch.h:734
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
Definition: PatternMatch.h:728
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Definition: PatternMatch.h:654
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Definition: PatternMatch.h:681
void replaceAllUsesWith(IRObjectWithUseList< OperandType > *from, ValueT &&to)
Definition: PatternMatch.h:640
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:612
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:132
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
detail::ValueImpl * getImpl() const
Definition: Value.h:243
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Base class for listeners.
Definition: Builders.h:266
Kind
The kind of listener.
Definition: Builders.h:268
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:287
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:300
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:372
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:168
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:461
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:501
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:464
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:496
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:475
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:492
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that the specified operation is about to be replaced with another operation.
Definition: PatternMatch.h:479
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:488
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:471
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:462
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that the specified operation is about to be replaced with a range of values,...
Definition: PatternMatch.h:483
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
Definition: PatternMatch.h:467
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:452
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:410
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:433
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that the specified operation is about to be replaced with another operation.
Definition: PatternMatch.h:418
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:407
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:444
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:437
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that the specified operation is about to be replaced with a range of values,...
Definition: PatternMatch.h:426
static bool classof(const OpBuilder::Listener *base)
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:318
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: PatternMatch.h:335
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: PatternMatch.h:322
virtual LogicalResult match(SourceOp op) const
Definition: PatternMatch.h:338
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:325
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:341
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:328